diff --git a/cmd/proxygenerator/interceptor.go b/cmd/proxygenerator/interceptor.go index f31286b9..e488d4ae 100644 --- a/cmd/proxygenerator/interceptor.go +++ b/cmd/proxygenerator/interceptor.go @@ -207,15 +207,10 @@ func visitPayloads( parent proto.Message, objs ...interface{}, ) error { - for i, obj := range objs { + for _, obj := range objs { ctx.SinglePayloadRequired = false switch o := obj.(type) { - case *common.Payload: - if o == nil { continue } - no, err := visitPayload(ctx, options, parent, o) - if err != nil { return err } - objs[i] = no case map[string]*common.Payload: for ix, x := range o { if nx, err := visitPayload(ctx, options, parent, x); err != nil { @@ -273,6 +268,14 @@ func visitPayloads( if options.SkipSearchAttributes { continue } {{end}} if o == nil { continue } + {{range $record.Payloads -}} + if o.{{.}} != nil { + no, err := visitPayload(ctx, options, o, o.{{.}}) + if err != nil { return err } + o.{{.}} = no + } + {{end}} + {{if $record.Methods}} if err := visitPayloads( ctx, options, @@ -281,6 +284,7 @@ func visitPayloads( o.{{.}}(), {{end}} ); err != nil { return err } + {{end}} {{end}} } } @@ -326,10 +330,11 @@ var interceptorTemplate = template.Must(template.New("interceptor").Parse(Interc // TypeRecord holds the state for a type referred to by the workflow service type TypeRecord struct { - Methods []string // List of methods on this type that can eventually lead to Payload(s) - Slice bool // The API refers to slices of this type - Map bool // The API refers to maps with this type as the value - Matches bool // We found methods on this type that can eventually lead to Payload(s) + Methods []string // List of methods on this type that can eventually lead to Payload(s) + Payloads []string // List of attributes on this type that are of type Payload + Slice bool // The API refers to slices of this type + Map bool // The API refers to maps with this type as the value + Matches bool // We found methods on this type that can eventually lead to Payload(s) } // isSlice returns true if a type is slice, false otherwise @@ -411,12 +416,12 @@ func pruneRecords(input map[string]*TypeRecord) map[string]*TypeRecord { // walk iterates the methods on a type and returns whether any of them can eventually lead to Payload(s) // The return type for each method on this type is walked recursively to decide which methods can lead to Payload(s) -func walk(desired []types.Type, typ types.Type, records *map[string]*TypeRecord) bool { +func walk(desired []types.Type, typ types.Type, records *map[string]*TypeRecord, checkDirectPayload bool) bool { typeName := typeName(typ) // If this type is a slice then walk the underlying type and then make a note we need to encode slices of this type if isSlice(typ) { - result := walk(desired, elemType(typ), records) + result := walk(desired, elemType(typ), records, checkDirectPayload) if result { record := (*records)[typeName] record.Slice = true @@ -426,7 +431,7 @@ func walk(desired []types.Type, typ types.Type, records *map[string]*TypeRecord) // If this type is a map then walk the underlying type and then make a note we need to encode maps with values of this type if isMap(typ) { - result := walk(desired, elemType(typ), records) + result := walk(desired, elemType(typ), records, checkDirectPayload) if result { record := (*records)[typeName] record.Map = true @@ -459,8 +464,18 @@ func walk(desired []types.Type, typ types.Type, records *map[string]*TypeRecord) // All the Get... methods return the relevant protobuf as the first result resultType := sig.Results().At(0).Type() + if checkDirectPayload && resultType.String() == "*go.temporal.io/api/common/v1.Payload" { + record.Matches = true + prefix, ok := strings.CutPrefix(methodName, "Get") + if !ok { + panic(fmt.Errorf("expected method to have a Get prefix: %s", methodName)) + } + record.Payloads = append(record.Payloads, prefix) + continue + } + // Check if this method returns a Payload(s) or if it leads (eventually) to a Type which refers to a Payload(s) - if typeMatches(resultType, desired...) || walk(desired, resultType, records) { + if typeMatches(resultType, desired...) || walk(desired, resultType, records, checkDirectPayload) { record.Matches = true record.Methods = append(record.Methods, methodName) } @@ -536,10 +551,10 @@ func generateInterceptor(cfg config) error { } sig := meth.Obj().Type().(*types.Signature) - walk(payloadTypes, sig.Params().At(1).Type(), &payloadRecords) - walk(failureTypes, sig.Params().At(1).Type(), &failureRecords) - walk(payloadTypes, sig.Results().At(0).Type(), &payloadRecords) - walk(failureTypes, sig.Results().At(0).Type(), &failureRecords) + walk(payloadTypes, sig.Params().At(1).Type(), &payloadRecords, true) + walk(failureTypes, sig.Params().At(1).Type(), &failureRecords, false) + walk(payloadTypes, sig.Results().At(0).Type(), &payloadRecords, true) + walk(failureTypes, sig.Results().At(0).Type(), &failureRecords, false) } for _, meth := range typeutil.IntuitiveMethodSet(operatorService, nil) { @@ -548,14 +563,14 @@ func generateInterceptor(cfg config) error { } sig := meth.Obj().Type().(*types.Signature) - walk(payloadTypes, sig.Params().At(1).Type(), &payloadRecords) - walk(failureTypes, sig.Params().At(1).Type(), &failureRecords) - walk(payloadTypes, sig.Results().At(0).Type(), &payloadRecords) - walk(failureTypes, sig.Results().At(0).Type(), &failureRecords) + walk(payloadTypes, sig.Params().At(1).Type(), &payloadRecords, true) + walk(failureTypes, sig.Params().At(1).Type(), &failureRecords, false) + walk(payloadTypes, sig.Results().At(0).Type(), &payloadRecords, true) + walk(failureTypes, sig.Results().At(0).Type(), &failureRecords, false) } - walk(payloadTypes, workflowExecutions, &payloadRecords) - walk(failureTypes, workflowExecutions, &failureRecords) + walk(payloadTypes, workflowExecutions, &payloadRecords, true) + walk(failureTypes, workflowExecutions, &failureRecords, false) payloadRecords = pruneRecords(payloadRecords) failureRecords = pruneRecords(failureRecords) diff --git a/proxy/interceptor.go b/proxy/interceptor.go index f9c0e652..c21a8c5c 100644 --- a/proxy/interceptor.go +++ b/proxy/interceptor.go @@ -205,19 +205,10 @@ func visitPayloads( parent proto.Message, objs ...interface{}, ) error { - for i, obj := range objs { + for _, obj := range objs { ctx.SinglePayloadRequired = false switch o := obj.(type) { - case *common.Payload: - if o == nil { - continue - } - no, err := visitPayload(ctx, options, parent, o) - if err != nil { - return err - } - objs[i] = no case map[string]*common.Payload: for ix, x := range o { if nx, err := visitPayload(ctx, options, parent, x); err != nil { @@ -263,6 +254,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -278,6 +270,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -292,6 +285,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -313,6 +307,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -338,6 +333,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -352,6 +348,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -371,6 +368,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -385,6 +383,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -399,6 +398,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -415,6 +415,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -430,13 +431,12 @@ func visitPayloads( if o == nil { continue } - if err := visitPayloads( - ctx, - options, - o, - o.GetInput(), - ); err != nil { - return err + if o.Input != nil { + no, err := visitPayload(ctx, options, o, o.Input) + if err != nil { + return err + } + o.Input = no } case *command.SignalExternalWorkflowExecutionCommandAttributes: @@ -444,6 +444,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -459,6 +460,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -476,6 +478,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -490,6 +493,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -504,6 +508,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -522,6 +527,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -536,6 +542,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -550,6 +557,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -571,6 +579,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -585,6 +594,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -599,6 +609,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -613,6 +624,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -634,6 +646,14 @@ func visitPayloads( if o == nil { continue } + if o.EncodedAttributes != nil { + no, err := visitPayload(ctx, options, o, o.EncodedAttributes) + if err != nil { + return err + } + o.EncodedAttributes = no + } + if err := visitPayloads( ctx, options, @@ -641,7 +661,6 @@ func visitPayloads( o.GetApplicationFailureInfo(), o.GetCanceledFailureInfo(), o.GetCause(), - o.GetEncodedAttributes(), o.GetResetWorkflowFailureInfo(), o.GetTimeoutFailureInfo(), ); err != nil { @@ -653,6 +672,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -667,6 +687,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -681,6 +702,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -695,6 +717,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -709,6 +732,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -723,6 +747,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -738,6 +763,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -752,6 +778,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -766,6 +793,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -780,6 +808,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -794,6 +823,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -808,6 +838,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -822,6 +853,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -843,6 +875,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -890,6 +923,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -906,6 +940,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -920,13 +955,12 @@ func visitPayloads( if o == nil { continue } - if err := visitPayloads( - ctx, - options, - o, - o.GetResult(), - ); err != nil { - return err + if o.Result != nil { + no, err := visitPayload(ctx, options, o, o.Result) + if err != nil { + return err + } + o.Result = no } case *history.NexusOperationFailedEventAttributes: @@ -934,6 +968,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -948,13 +983,12 @@ func visitPayloads( if o == nil { continue } - if err := visitPayloads( - ctx, - options, - o, - o.GetInput(), - ); err != nil { - return err + if o.Input != nil { + no, err := visitPayload(ctx, options, o, o.Input) + if err != nil { + return err + } + o.Input = no } case *history.NexusOperationTimedOutEventAttributes: @@ -962,6 +996,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -976,6 +1011,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -991,6 +1027,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1008,6 +1045,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1022,6 +1060,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1036,6 +1075,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1050,6 +1090,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1069,6 +1110,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1083,6 +1125,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1098,6 +1141,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1117,6 +1161,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1131,6 +1176,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1145,6 +1191,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1159,6 +1206,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1173,6 +1221,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1188,6 +1237,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1202,6 +1252,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1216,6 +1267,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1237,6 +1289,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1251,13 +1304,12 @@ func visitPayloads( if o == nil { continue } - if err := visitPayloads( - ctx, - options, - o, - o.GetDescription(), - ); err != nil { - return err + if o.Description != nil { + no, err := visitPayload(ctx, options, o, o.Description) + if err != nil { + return err + } + o.Description = no } case *nexus.Request: @@ -1265,6 +1317,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1279,6 +1332,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1293,13 +1347,12 @@ func visitPayloads( if o == nil { continue } - if err := visitPayloads( - ctx, - options, - o, - o.GetPayload(), - ); err != nil { - return err + if o.Payload != nil { + no, err := visitPayload(ctx, options, o, o.Payload) + if err != nil { + return err + } + o.Payload = no } case *nexus.StartOperationResponse: @@ -1307,6 +1360,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1321,13 +1375,12 @@ func visitPayloads( if o == nil { continue } - if err := visitPayloads( - ctx, - options, - o, - o.GetPayload(), - ); err != nil { - return err + if o.Payload != nil { + no, err := visitPayload(ctx, options, o, o.Payload) + if err != nil { + return err + } + o.Payload = no } case *operatorservice.CreateNexusEndpointRequest: @@ -1335,6 +1388,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1349,6 +1403,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1363,6 +1418,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1377,6 +1433,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1391,6 +1448,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1405,6 +1463,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1426,6 +1485,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1447,6 +1507,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1469,6 +1530,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1484,6 +1546,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1498,6 +1561,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1519,6 +1583,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1534,14 +1599,19 @@ func visitPayloads( if o == nil { continue } - if err := visitPayloads( - ctx, - options, - o, - o.GetDetails(), - o.GetSummary(), - ); err != nil { - return err + if o.Details != nil { + no, err := visitPayload(ctx, options, o, o.Details) + if err != nil { + return err + } + o.Details = no + } + if o.Summary != nil { + no, err := visitPayload(ctx, options, o, o.Summary) + if err != nil { + return err + } + o.Summary = no } case *update.Input: @@ -1549,6 +1619,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1564,6 +1635,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1579,6 +1651,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1600,6 +1673,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1614,6 +1688,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1632,6 +1707,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1653,6 +1729,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1675,6 +1752,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1690,6 +1768,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1711,6 +1790,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1726,6 +1806,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1747,6 +1828,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1761,6 +1843,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1777,6 +1860,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1791,6 +1875,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1807,6 +1892,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1825,6 +1911,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1846,6 +1933,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1861,6 +1949,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1882,6 +1971,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1897,6 +1987,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1911,6 +2002,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1925,6 +2017,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1939,6 +2032,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1953,6 +2047,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1967,6 +2062,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1981,6 +2077,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -1995,6 +2092,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2009,6 +2107,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2030,6 +2129,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2046,6 +2146,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2060,6 +2161,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2074,6 +2176,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2091,6 +2194,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2105,6 +2209,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2119,6 +2224,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2133,6 +2239,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2147,6 +2254,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2161,6 +2269,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2175,6 +2284,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2189,6 +2299,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2203,6 +2314,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2218,6 +2330,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2232,6 +2345,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2247,6 +2361,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2261,6 +2376,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2275,6 +2391,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2290,6 +2407,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2306,6 +2424,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2321,6 +2440,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2336,6 +2456,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2350,6 +2471,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2364,6 +2486,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2379,6 +2502,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2398,6 +2522,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2413,6 +2538,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2428,6 +2554,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2448,6 +2575,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2462,6 +2590,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2476,6 +2605,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2491,6 +2621,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, @@ -2505,6 +2636,7 @@ func visitPayloads( if o == nil { continue } + if err := visitPayloads( ctx, options, diff --git a/proxy/interceptor_test.go b/proxy/interceptor_test.go index 67536ced..1c28042d 100644 --- a/proxy/interceptor_test.go +++ b/proxy/interceptor_test.go @@ -120,6 +120,27 @@ func TestVisitPayloads(t *testing.T) { }, ) require.NoError(err) + + msg := &history.HistoryEvent{ + Attributes: &history.HistoryEvent_NexusOperationScheduledEventAttributes{ + NexusOperationScheduledEventAttributes: &history.NexusOperationScheduledEventAttributes{ + Input: inputPayload(), + }, + }, + } + err = VisitPayloads( + context.Background(), + msg, + VisitPayloadsOptions{ + Visitor: func(vpc *VisitPayloadsContext, p []*common.Payload) ([]*common.Payload, error) { + require.True(vpc.SinglePayloadRequired) + require.Equal([]byte("test"), p[0].Data) + return []*common.Payload{{Data: []byte("visited")}}, nil + }, + }, + ) + require.Equal([]byte("visited"), msg.GetNexusOperationScheduledEventAttributes().Input.Data) + require.NoError(err) } func TestVisitPayloads_NestedParent(t *testing.T) {