Skip to content

Commit

Permalink
Allow workflow interceptors to add nexus headers (temporalio#1604)
Browse files Browse the repository at this point in the history
  • Loading branch information
bergundy authored Aug 20, 2024
1 parent 2a02c48 commit edc3c6c
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 20 deletions.
5 changes: 5 additions & 0 deletions interceptor/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ type HandleQueryInput = internal.HandleQueryInput
// NOTE: Experimental
type UpdateInput = internal.UpdateInput

// ExecuteNexusOperationInput is the input to WorkflowOutboundInterceptor.ExecuteNexusOperation.
//
// NOTE: Experimental
type ExecuteNexusOperationInput = internal.ExecuteNexusOperationInput

// RequestCancelNexusOperationInput is the input to WorkflowOutboundInterceptor.RequestCancelNexusOperation.
//
// NOTE: Experimental
Expand Down
21 changes: 19 additions & 2 deletions internal/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"context"
"time"

"github.com/nexus-rpc/sdk-go/nexus"
commonpb "go.temporal.io/api/common/v1"
enumspb "go.temporal.io/api/enums/v1"
updatepb "go.temporal.io/api/update/v1"
Expand Down Expand Up @@ -169,13 +170,29 @@ type HandleQueryInput struct {
Args []interface{}
}

// ExecuteNexusOperationInput is the input to WorkflowOutboundInterceptor.ExecuteNexusOperation.
//
// NOTE: Experimental
type ExecuteNexusOperationInput struct {
// Client to start the operation with.
Client NexusClient
// Operation name or OperationReference from the Nexus SDK.
Operation any
// Operation input.
Input any
// Options for starting the operation.
Options NexusOperationOptions
// Header to attach to the request.
NexusHeader nexus.Header
}

// RequestCancelNexusOperationInput is the input to WorkflowOutboundInterceptor.RequestCancelNexusOperation.
//
// NOTE: Experimental
type RequestCancelNexusOperationInput struct {
// Client that was used to start the operation.
Client NexusClient
// Operation name.
// Operation name or OperationReference from the Nexus SDK.
Operation any
// Operation ID. May be empty if the operation is synchronous or has not started yet.
ID string
Expand Down Expand Up @@ -300,7 +317,7 @@ type WorkflowOutboundInterceptor interface {
// ExecuteNexusOperation intercepts NexusClient.ExecuteOperation.
//
// NOTE: Experimental
ExecuteNexusOperation(ctx Context, client NexusClient, operation any, input any, options NexusOperationOptions) NexusOperationFuture
ExecuteNexusOperation(ctx Context, input ExecuteNexusOperationInput) NexusOperationFuture
// RequestCancelNexusOperation intercepts Nexus Operation cancelation via context.
//
// NOTE: Experimental
Expand Down
7 changes: 2 additions & 5 deletions internal/interceptor_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,12 +393,9 @@ func (w *WorkflowOutboundInterceptorBase) NewContinueAsNewError(
// WorkflowOutboundInterceptor.ExecuteNexusOperation.
func (w *WorkflowOutboundInterceptorBase) ExecuteNexusOperation(
ctx Context,
client NexusClient,
operation any,
input any,
options NexusOperationOptions,
input ExecuteNexusOperationInput,
) NexusOperationFuture {
return w.Next.ExecuteNexusOperation(ctx, client, operation, input, options)
return w.Next.ExecuteNexusOperation(ctx, input)
}

// RequestCancelNexusOperation implements
Expand Down
34 changes: 21 additions & 13 deletions internal/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"strings"
"time"

"github.com/nexus-rpc/sdk-go/nexus"
"golang.org/x/exp/constraints"
"golang.org/x/exp/slices"

Expand Down Expand Up @@ -2273,37 +2274,44 @@ func (c nexusClient) Service() string {
func (c nexusClient) ExecuteOperation(ctx Context, operation any, input any, options NexusOperationOptions) NexusOperationFuture {
assertNotInReadOnlyState(ctx)
i := getWorkflowOutboundInterceptor(ctx)
return i.ExecuteNexusOperation(ctx, c, operation, input, options)
return i.ExecuteNexusOperation(ctx, ExecuteNexusOperationInput{
Client: c,
Operation: operation,
Input: input,
Options: options,
NexusHeader: nexus.Header{},
})
}

func (wc *workflowEnvironmentInterceptor) prepareNexusOperationParams(ctx Context, client NexusClient, operation any, input any, options NexusOperationOptions) (executeNexusOperationParams, error) {
func (wc *workflowEnvironmentInterceptor) prepareNexusOperationParams(ctx Context, input ExecuteNexusOperationInput) (executeNexusOperationParams, error) {
dc := WithWorkflowContext(ctx, wc.env.GetDataConverter())

var ok bool
var operationName string
if operationName, ok = operation.(string); ok {
} else if regOp, ok := operation.(interface{ Name() string }); ok {
if operationName, ok = input.Operation.(string); ok {
} else if regOp, ok := input.Operation.(interface{ Name() string }); ok {
operationName = regOp.Name()
} else {
return executeNexusOperationParams{}, fmt.Errorf("invalid 'operation' parameter, must be an OperationReference or a string")
}
// TODO(bergundy): Validate operation types against input once there's a good way to extract the generic types from
// OperationReference in the Nexus Go SDK.

payload, err := dc.ToPayload(input)
payload, err := dc.ToPayload(input.Input)
if err != nil {
return executeNexusOperationParams{}, err
}

return executeNexusOperationParams{
client: client,
operation: operationName,
input: payload,
options: options,
client: input.Client,
operation: operationName,
input: payload,
options: input.Options,
nexusHeader: input.NexusHeader,
}, nil
}

func (wc *workflowEnvironmentInterceptor) ExecuteNexusOperation(ctx Context, client NexusClient, operation any, input any, options NexusOperationOptions) NexusOperationFuture {
func (wc *workflowEnvironmentInterceptor) ExecuteNexusOperation(ctx Context, input ExecuteNexusOperationInput) NexusOperationFuture {
mainFuture, mainSettable := newDecodeFuture(ctx, nil /* this param is never used */)
executionFuture, executionSettable := NewFuture(ctx)
result := &nexusOperationFutureImpl{
Expand All @@ -2320,7 +2328,7 @@ func (wc *workflowEnvironmentInterceptor) ExecuteNexusOperation(ctx Context, cli

ctxDone, cancellable := ctx.Done().(*channelImpl)
cancellationCallback := &receiveCallback{}
params, err := wc.prepareNexusOperationParams(ctx, client, operation, input, options)
params, err := wc.prepareNexusOperationParams(ctx, input)
if err != nil {
executionSettable.Set(nil, err)
mainSettable.Set(nil, err)
Expand Down Expand Up @@ -2349,8 +2357,8 @@ func (wc *workflowEnvironmentInterceptor) ExecuteNexusOperation(ctx Context, cli
if ctx.Err() == ErrCanceled && !mainFuture.IsReady() {
// Go back to the top of the interception chain.
getWorkflowOutboundInterceptor(ctx).RequestCancelNexusOperation(ctx, RequestCancelNexusOperationInput{
Client: client,
Operation: operation,
Client: input.Client,
Operation: input.Operation,
ID: operationID,
seq: seq,
})
Expand Down
84 changes: 84 additions & 0 deletions test/nexus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import (
"go.temporal.io/api/operatorservice/v1"

"go.temporal.io/sdk/client"
"go.temporal.io/sdk/interceptor"
"go.temporal.io/sdk/internal/common/metrics"
ilog "go.temporal.io/sdk/internal/log"
"go.temporal.io/sdk/temporal"
Expand Down Expand Up @@ -942,3 +943,86 @@ func TestWorkflowTestSuite_NexusSyncOperation_ClientMethods_Panic(t *testing.T)
require.NoError(t, env.GetWorkflowError())
require.Equal(t, "not implemented in the test environment", panicReason)
}

type nexusInterceptor struct {
interceptor.WorkerInterceptorBase
interceptor.WorkflowInboundInterceptorBase
interceptor.WorkflowOutboundInterceptorBase
}

func (i *nexusInterceptor) InterceptWorkflow(
ctx workflow.Context,
next interceptor.WorkflowInboundInterceptor,
) interceptor.WorkflowInboundInterceptor {
i.WorkflowInboundInterceptorBase.Next = next
return i
}

func (i *nexusInterceptor) Init(outbound interceptor.WorkflowOutboundInterceptor) error {
i.WorkflowOutboundInterceptorBase.Next = outbound
return i.WorkflowInboundInterceptorBase.Next.Init(i)
}

func (i *nexusInterceptor) ExecuteNexusOperation(
ctx workflow.Context,
input interceptor.ExecuteNexusOperationInput,
) workflow.NexusOperationFuture {
input.NexusHeader["test"] = "present"
return i.WorkflowOutboundInterceptorBase.Next.ExecuteNexusOperation(ctx, input)
}

func TestInterceptors(t *testing.T) {
if os.Getenv("DISABLE_NEXUS_TESTS") != "" {
t.SkipNow()
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()
tc := newTestContext(t, ctx)

op := temporalnexus.NewSyncOperation("op", func(ctx context.Context, c client.Client, _ nexus.NoValue, opts nexus.StartOperationOptions) (string, error) {
return opts.Header["test"], nil
})

wf := func(ctx workflow.Context) error {
c := workflow.NewNexusClient(tc.endpoint, "test")
fut := c.ExecuteOperation(ctx, op, nil, workflow.NexusOperationOptions{})
var res string

var exec workflow.NexusOperationExecution
if err := fut.GetNexusOperationExecution().Get(ctx, &exec); err != nil {
return fmt.Errorf("expected start to succeed: %w", err)
}
if exec.OperationID != "" {
return fmt.Errorf("expected empty operation ID")
}
if err := fut.Get(ctx, &res); err != nil {
return err
}
// If the operation didn't fail the only expected result is "present" (header value injected by the interceptor).
if res != "present" {
return fmt.Errorf("unexpected result: %v", res)
}
return nil
}

w := worker.New(tc.client, tc.taskQueue, worker.Options{
Interceptors: []interceptor.WorkerInterceptor{
&nexusInterceptor{},
},
})
service := nexus.NewService("test")
require.NoError(t, service.Register(op))
w.RegisterNexusService(service)
w.RegisterWorkflow(wf)
require.NoError(t, w.Start())
t.Cleanup(w.Stop)

run, err := tc.client.ExecuteWorkflow(ctx, client.StartWorkflowOptions{
TaskQueue: tc.taskQueue,
// The endpoint registry may take a bit to propagate to the history service, use a shorter workflow task
// timeout to speed up the attempts.
WorkflowTaskTimeout: time.Second,
}, wf)
require.NoError(t, err)
require.NoError(t, run.Get(ctx, nil))
}

0 comments on commit edc3c6c

Please sign in to comment.