diff --git a/interceptor/interceptor.go b/interceptor/interceptor.go index 6061e88b1..aa269ca5c 100644 --- a/interceptor/interceptor.go +++ b/interceptor/interceptor.go @@ -131,6 +131,11 @@ type HandleQueryInput = internal.HandleQueryInput // NOTE: Experimental type UpdateInput = internal.UpdateInput +// ExecuteNexusOperationInput is the input to WorkflowOutboundInterceptor.ExecuteNexusOperation. +// +// NOTE: Experimental +type ExecuteNexusOperationInput = internal.ExecuteNexusOperationInput + // RequestCancelNexusOperationInput is the input to WorkflowOutboundInterceptor.RequestCancelNexusOperation. // // NOTE: Experimental diff --git a/internal/interceptor.go b/internal/interceptor.go index 51f64e989..d2a0accc7 100644 --- a/internal/interceptor.go +++ b/internal/interceptor.go @@ -26,6 +26,7 @@ import ( "context" "time" + "github.com/nexus-rpc/sdk-go/nexus" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" updatepb "go.temporal.io/api/update/v1" @@ -169,13 +170,29 @@ type HandleQueryInput struct { Args []interface{} } +// ExecuteNexusOperationInput is the input to WorkflowOutboundInterceptor.ExecuteNexusOperation. +// +// NOTE: Experimental +type ExecuteNexusOperationInput struct { + // Client to start the operation with. + Client NexusClient + // Operation name or OperationReference from the Nexus SDK. + Operation any + // Operation input. + Input any + // Options for starting the operation. + Options NexusOperationOptions + // Header to attach to the request. + NexusHeader nexus.Header +} + // RequestCancelNexusOperationInput is the input to WorkflowOutboundInterceptor.RequestCancelNexusOperation. // // NOTE: Experimental type RequestCancelNexusOperationInput struct { // Client that was used to start the operation. Client NexusClient - // Operation name. + // Operation name or OperationReference from the Nexus SDK. Operation any // Operation ID. May be empty if the operation is synchronous or has not started yet. ID string @@ -300,7 +317,7 @@ type WorkflowOutboundInterceptor interface { // ExecuteNexusOperation intercepts NexusClient.ExecuteOperation. // // NOTE: Experimental - ExecuteNexusOperation(ctx Context, client NexusClient, operation any, input any, options NexusOperationOptions) NexusOperationFuture + ExecuteNexusOperation(ctx Context, input ExecuteNexusOperationInput) NexusOperationFuture // RequestCancelNexusOperation intercepts Nexus Operation cancelation via context. // // NOTE: Experimental diff --git a/internal/interceptor_base.go b/internal/interceptor_base.go index f000f0319..7fd5b20c1 100644 --- a/internal/interceptor_base.go +++ b/internal/interceptor_base.go @@ -393,12 +393,9 @@ func (w *WorkflowOutboundInterceptorBase) NewContinueAsNewError( // WorkflowOutboundInterceptor.ExecuteNexusOperation. func (w *WorkflowOutboundInterceptorBase) ExecuteNexusOperation( ctx Context, - client NexusClient, - operation any, - input any, - options NexusOperationOptions, + input ExecuteNexusOperationInput, ) NexusOperationFuture { - return w.Next.ExecuteNexusOperation(ctx, client, operation, input, options) + return w.Next.ExecuteNexusOperation(ctx, input) } // RequestCancelNexusOperation implements diff --git a/internal/workflow.go b/internal/workflow.go index 7c1eb676a..f5e8f2604 100644 --- a/internal/workflow.go +++ b/internal/workflow.go @@ -31,6 +31,7 @@ import ( "strings" "time" + "github.com/nexus-rpc/sdk-go/nexus" "golang.org/x/exp/constraints" "golang.org/x/exp/slices" @@ -2273,16 +2274,22 @@ func (c nexusClient) Service() string { func (c nexusClient) ExecuteOperation(ctx Context, operation any, input any, options NexusOperationOptions) NexusOperationFuture { assertNotInReadOnlyState(ctx) i := getWorkflowOutboundInterceptor(ctx) - return i.ExecuteNexusOperation(ctx, c, operation, input, options) + return i.ExecuteNexusOperation(ctx, ExecuteNexusOperationInput{ + Client: c, + Operation: operation, + Input: input, + Options: options, + NexusHeader: nexus.Header{}, + }) } -func (wc *workflowEnvironmentInterceptor) prepareNexusOperationParams(ctx Context, client NexusClient, operation any, input any, options NexusOperationOptions) (executeNexusOperationParams, error) { +func (wc *workflowEnvironmentInterceptor) prepareNexusOperationParams(ctx Context, input ExecuteNexusOperationInput) (executeNexusOperationParams, error) { dc := WithWorkflowContext(ctx, wc.env.GetDataConverter()) var ok bool var operationName string - if operationName, ok = operation.(string); ok { - } else if regOp, ok := operation.(interface{ Name() string }); ok { + if operationName, ok = input.Operation.(string); ok { + } else if regOp, ok := input.Operation.(interface{ Name() string }); ok { operationName = regOp.Name() } else { return executeNexusOperationParams{}, fmt.Errorf("invalid 'operation' parameter, must be an OperationReference or a string") @@ -2290,20 +2297,21 @@ func (wc *workflowEnvironmentInterceptor) prepareNexusOperationParams(ctx Contex // TODO(bergundy): Validate operation types against input once there's a good way to extract the generic types from // OperationReference in the Nexus Go SDK. - payload, err := dc.ToPayload(input) + payload, err := dc.ToPayload(input.Input) if err != nil { return executeNexusOperationParams{}, err } return executeNexusOperationParams{ - client: client, - operation: operationName, - input: payload, - options: options, + client: input.Client, + operation: operationName, + input: payload, + options: input.Options, + nexusHeader: input.NexusHeader, }, nil } -func (wc *workflowEnvironmentInterceptor) ExecuteNexusOperation(ctx Context, client NexusClient, operation any, input any, options NexusOperationOptions) NexusOperationFuture { +func (wc *workflowEnvironmentInterceptor) ExecuteNexusOperation(ctx Context, input ExecuteNexusOperationInput) NexusOperationFuture { mainFuture, mainSettable := newDecodeFuture(ctx, nil /* this param is never used */) executionFuture, executionSettable := NewFuture(ctx) result := &nexusOperationFutureImpl{ @@ -2320,7 +2328,7 @@ func (wc *workflowEnvironmentInterceptor) ExecuteNexusOperation(ctx Context, cli ctxDone, cancellable := ctx.Done().(*channelImpl) cancellationCallback := &receiveCallback{} - params, err := wc.prepareNexusOperationParams(ctx, client, operation, input, options) + params, err := wc.prepareNexusOperationParams(ctx, input) if err != nil { executionSettable.Set(nil, err) mainSettable.Set(nil, err) @@ -2349,8 +2357,8 @@ func (wc *workflowEnvironmentInterceptor) ExecuteNexusOperation(ctx Context, cli if ctx.Err() == ErrCanceled && !mainFuture.IsReady() { // Go back to the top of the interception chain. getWorkflowOutboundInterceptor(ctx).RequestCancelNexusOperation(ctx, RequestCancelNexusOperationInput{ - Client: client, - Operation: operation, + Client: input.Client, + Operation: input.Operation, ID: operationID, seq: seq, }) diff --git a/test/nexus_test.go b/test/nexus_test.go index 530df8b96..13bb9659e 100644 --- a/test/nexus_test.go +++ b/test/nexus_test.go @@ -44,6 +44,7 @@ import ( "go.temporal.io/api/operatorservice/v1" "go.temporal.io/sdk/client" + "go.temporal.io/sdk/interceptor" "go.temporal.io/sdk/internal/common/metrics" ilog "go.temporal.io/sdk/internal/log" "go.temporal.io/sdk/temporal" @@ -942,3 +943,86 @@ func TestWorkflowTestSuite_NexusSyncOperation_ClientMethods_Panic(t *testing.T) require.NoError(t, env.GetWorkflowError()) require.Equal(t, "not implemented in the test environment", panicReason) } + +type nexusInterceptor struct { + interceptor.WorkerInterceptorBase + interceptor.WorkflowInboundInterceptorBase + interceptor.WorkflowOutboundInterceptorBase +} + +func (i *nexusInterceptor) InterceptWorkflow( + ctx workflow.Context, + next interceptor.WorkflowInboundInterceptor, +) interceptor.WorkflowInboundInterceptor { + i.WorkflowInboundInterceptorBase.Next = next + return i +} + +func (i *nexusInterceptor) Init(outbound interceptor.WorkflowOutboundInterceptor) error { + i.WorkflowOutboundInterceptorBase.Next = outbound + return i.WorkflowInboundInterceptorBase.Next.Init(i) +} + +func (i *nexusInterceptor) ExecuteNexusOperation( + ctx workflow.Context, + input interceptor.ExecuteNexusOperationInput, +) workflow.NexusOperationFuture { + input.NexusHeader["test"] = "present" + return i.WorkflowOutboundInterceptorBase.Next.ExecuteNexusOperation(ctx, input) +} + +func TestInterceptors(t *testing.T) { + if os.Getenv("DISABLE_NEXUS_TESTS") != "" { + t.SkipNow() + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + tc := newTestContext(t, ctx) + + op := temporalnexus.NewSyncOperation("op", func(ctx context.Context, c client.Client, _ nexus.NoValue, opts nexus.StartOperationOptions) (string, error) { + return opts.Header["test"], nil + }) + + wf := func(ctx workflow.Context) error { + c := workflow.NewNexusClient(tc.endpoint, "test") + fut := c.ExecuteOperation(ctx, op, nil, workflow.NexusOperationOptions{}) + var res string + + var exec workflow.NexusOperationExecution + if err := fut.GetNexusOperationExecution().Get(ctx, &exec); err != nil { + return fmt.Errorf("expected start to succeed: %w", err) + } + if exec.OperationID != "" { + return fmt.Errorf("expected empty operation ID") + } + if err := fut.Get(ctx, &res); err != nil { + return err + } + // If the operation didn't fail the only expected result is "present" (header value injected by the interceptor). + if res != "present" { + return fmt.Errorf("unexpected result: %v", res) + } + return nil + } + + w := worker.New(tc.client, tc.taskQueue, worker.Options{ + Interceptors: []interceptor.WorkerInterceptor{ + &nexusInterceptor{}, + }, + }) + service := nexus.NewService("test") + require.NoError(t, service.Register(op)) + w.RegisterNexusService(service) + w.RegisterWorkflow(wf) + require.NoError(t, w.Start()) + t.Cleanup(w.Stop) + + run, err := tc.client.ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + TaskQueue: tc.taskQueue, + // The endpoint registry may take a bit to propagate to the history service, use a shorter workflow task + // timeout to speed up the attempts. + WorkflowTaskTimeout: time.Second, + }, wf) + require.NoError(t, err) + require.NoError(t, run.Get(ctx, nil)) +}