diff --git a/converter/composite_data_converter.go b/converter/composite_data_converter.go index 41ebbb132..8f06e18f6 100644 --- a/converter/composite_data_converter.go +++ b/converter/composite_data_converter.go @@ -42,7 +42,7 @@ type ( // Order is important here because during serialization DataConverter will try PayloadsConverters in // that order until PayloadConverter returns non nil payload. // Last PayloadConverter should always serialize the value (JSONPayloadConverter is good candidate for it). -func NewCompositeDataConverter(payloadConverters ...PayloadConverter) *CompositeDataConverter { +func NewCompositeDataConverter(payloadConverters ...PayloadConverter) DataConverter { dc := &CompositeDataConverter{ payloadConverters: make(map[string]PayloadConverter, len(payloadConverters)), orderedEncodings: make([]string, len(payloadConverters)), diff --git a/converter/default_data_converter.go b/converter/default_data_converter.go index b0096be02..0b8d9a06a 100644 --- a/converter/default_data_converter.go +++ b/converter/default_data_converter.go @@ -41,6 +41,6 @@ var ( ) // GetDefaultDataConverter returns default data converter used by Temporal worker. -func GetDefaultDataConverter() *CompositeDataConverter { +func GetDefaultDataConverter() DataConverter { return defaultDataConverter } diff --git a/internal/headers.go b/internal/headers.go index cac64491c..02b535c5e 100644 --- a/internal/headers.go +++ b/internal/headers.go @@ -27,41 +27,53 @@ package internal import ( "context" + "go.temporal.io/sdk/converter" + commonpb "go.temporal.io/api/common/v1" ) // HeaderWriter is an interface to write information to temporal headers -type HeaderWriter interface { - Set(string, *commonpb.Payload) -} +type ( + HeaderWriter interface { + Set(string, *commonpb.Payload) + } -// HeaderReader is an interface to read information from temporal headers -type HeaderReader interface { - Get(string) (*commonpb.Payload, bool) - ForEachKey(handler func(string, *commonpb.Payload) error) error -} + // HeaderReader is an interface to read information from temporal headers + HeaderReader interface { + Get(string) (*commonpb.Payload, bool) + ForEachKey(handler func(string, *commonpb.Payload) error) error + } -// ContextPropagator is an interface that determines what information from -// context to pass along -type ContextPropagator interface { - // Inject injects information from a Go Context into headers - Inject(context.Context, HeaderWriter) error + // ContextPropagator is an interface that determines what information from + // context to pass along + ContextPropagator interface { + // Inject injects information from a Go Context into headers + Inject(context.Context, HeaderWriter) error - // Extract extracts context information from headers and returns a context - // object - Extract(context.Context, HeaderReader) (context.Context, error) + // Extract extracts context information from headers and returns a context + // object + Extract(context.Context, HeaderReader) (context.Context, error) - // InjectFromWorkflow injects information from workflow context into headers - InjectFromWorkflow(Context, HeaderWriter) error + // InjectFromWorkflow injects information from workflow context into headers + InjectFromWorkflow(Context, HeaderWriter) error - // ExtractToWorkflow extracts context information from headers and returns - // a workflow context - ExtractToWorkflow(Context, HeaderReader) (Context, error) -} + // ExtractToWorkflow extracts context information from headers and returns + // a workflow context + ExtractToWorkflow(Context, HeaderReader) (Context, error) + } -type headerReader struct { - header *commonpb.Header -} + // ContextAware is an optional interface that can be implemented alongside DataConverter. + // This interface allows Temporal to pass Workflow/Activity contexts to the DataConverter + // so that it may tailor it's behaviour. + ContextAware interface { + WithWorkflowContext(ctx Context) converter.DataConverter + WithContext(ctx context.Context) converter.DataConverter + } + + headerReader struct { + header *commonpb.Header + } +) func (hr *headerReader) ForEachKey(handler func(string, *commonpb.Payload) error) error { if hr.header == nil { @@ -106,3 +118,24 @@ func NewHeaderWriter(header *commonpb.Header) HeaderWriter { } return &headerWriter{header: header} } + +// WithWorkflowContext returns a new DataConverter tailored to the passed Workflow context if +// the DataConverter implements the ContextAware interface. Otherwise the DataConverter is returned +// as-is. +func WithWorkflowContext(ctx Context, dc converter.DataConverter) converter.DataConverter { + if d, ok := dc.(ContextAware); ok { + return d.WithWorkflowContext(ctx) + } + return dc +} + +// WithContext returns a new DataConverter tailored to the passed Workflow/Activity context if +// the DataConverter implements the ContextAware interface. Otherwise the DataConverter is returned +// as-is. This is generally used for Activity context but can be context for a Workflow if we're +// not yet executing the workflow so do not have a workflow.Context. +func WithContext(ctx context.Context, dc converter.DataConverter) converter.DataConverter { + if d, ok := dc.(ContextAware); ok { + return d.WithContext(ctx) + } + return dc +} diff --git a/internal/internal_worker.go b/internal/internal_worker.go index 3674fdb4a..f8897c020 100644 --- a/internal/internal_worker.go +++ b/internal/internal_worker.go @@ -746,7 +746,7 @@ type workflowExecutor struct { func (we *workflowExecutor) Execute(ctx Context, input *commonpb.Payloads) (*commonpb.Payloads, error) { var args []interface{} - dataConverter := getWorkflowEnvOptions(ctx).DataConverter + dataConverter := WithWorkflowContext(ctx, getWorkflowEnvOptions(ctx).DataConverter) fnType := reflect.TypeOf(we.fn) decoded, err := decodeArgsToValues(dataConverter, fnType, input) @@ -832,11 +832,15 @@ func (ae *activityExecutor) executeWithActualArgsWithoutParseResult(ctx context. } func getDataConverterFromActivityCtx(ctx context.Context) converter.DataConverter { + var dataConverter converter.DataConverter + env := getActivityEnvironmentFromCtx(ctx) - if env == nil || env.dataConverter == nil { - return converter.GetDefaultDataConverter() + if env != nil && env.dataConverter != nil { + dataConverter = env.dataConverter + } else { + dataConverter = converter.GetDefaultDataConverter() } - return env.dataConverter + return WithContext(ctx, dataConverter) } func getNamespaceFromActivityCtx(ctx context.Context) string { diff --git a/internal/internal_workflow.go b/internal/internal_workflow.go index 0ed168f8a..59c206938 100644 --- a/internal/internal_workflow.go +++ b/internal/internal_workflow.go @@ -1238,10 +1238,15 @@ func setWorkflowEnvOptionsIfNotExist(ctx Context) Context { func getDataConverterFromWorkflowContext(ctx Context) converter.DataConverter { options := getWorkflowEnvOptions(ctx) - if options == nil || options.DataConverter == nil { - return converter.GetDefaultDataConverter() + var dataConverter converter.DataConverter + + if options != nil && options.DataConverter != nil { + dataConverter = options.DataConverter + } else { + dataConverter = converter.GetDefaultDataConverter() } - return options.DataConverter + + return WithWorkflowContext(ctx, dataConverter) } func getRegistryFromWorkflowContext(ctx Context) *registry { diff --git a/internal/internal_workflow_client.go b/internal/internal_workflow_client.go index 2466e5e3e..df582433d 100644 --- a/internal/internal_workflow_client.go +++ b/internal/internal_workflow_client.go @@ -172,8 +172,9 @@ func (wc *WorkflowClient) StartWorkflow( runTimeout := options.WorkflowRunTimeout workflowTaskTimeout := options.WorkflowTaskTimeout + dataConverter := WithContext(ctx, wc.dataConverter) // Validate type and its arguments. - workflowType, input, err := getValidatedWorkflowFunction(workflowFunc, args, wc.dataConverter, wc.registry) + workflowType, input, err := getValidatedWorkflowFunction(workflowFunc, args, dataConverter, wc.registry) if err != nil { return nil, err } diff --git a/internal/stateful_data_converter_test.go b/internal/stateful_data_converter_test.go new file mode 100644 index 000000000..b8e25e152 --- /dev/null +++ b/internal/stateful_data_converter_test.go @@ -0,0 +1,152 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package internal + +import ( + "context" + "go.temporal.io/sdk/converter" + "testing" + + "github.com/stretchr/testify/require" + commonpb "go.temporal.io/api/common/v1" +) + +type ContextAwareDataConverter struct { + dataConverter converter.DataConverter + prefix string +} + +type contextKeyType int + +const ( + prefixContextKey contextKeyType = iota +) + +func (dc *ContextAwareDataConverter) ToPayload(value interface{}) (*commonpb.Payload, error) { + return dc.dataConverter.ToPayload(value) +} + +func (dc *ContextAwareDataConverter) ToPayloads(values ...interface{}) (*commonpb.Payloads, error) { + return dc.dataConverter.ToPayloads(values) +} + +func (dc *ContextAwareDataConverter) FromPayload(payload *commonpb.Payload, valuePtr interface{}) error { + return dc.dataConverter.FromPayload(payload, valuePtr) +} + +func (dc *ContextAwareDataConverter) FromPayloads(payloads *commonpb.Payloads, valuePtrs ...interface{}) error { + return dc.dataConverter.FromPayloads(payloads, valuePtrs...) +} + +func (dc *ContextAwareDataConverter) ToString(payload *commonpb.Payload) string { + if dc.prefix != "" { + return dc.prefix + ": " + dc.dataConverter.ToString(payload) + } + + return dc.dataConverter.ToString(payload) +} + +func (dc *ContextAwareDataConverter) ToStrings(payloads *commonpb.Payloads) []string { + var result []string + for _, payload := range payloads.GetPayloads() { + result = append(result, dc.ToString(payload)) + } + + return result +} + +func (dc *ContextAwareDataConverter) WithContext(ctx context.Context) converter.DataConverter { + v := ctx.Value(prefixContextKey) + prefix, ok := v.(string) + if !ok { + return dc + } + + return &ContextAwareDataConverter{ + dataConverter: dc.dataConverter, + prefix: prefix, + } +} + +func (dc *ContextAwareDataConverter) WithWorkflowContext(ctx Context) converter.DataConverter { + v := ctx.Value(prefixContextKey) + prefix, ok := v.(string) + if !ok { + return dc + } + + return &ContextAwareDataConverter{ + dataConverter: dc.dataConverter, + prefix: prefix, + } +} + +func newContextAwareDataConverter(dataConverter converter.DataConverter) converter.DataConverter { + return &ContextAwareDataConverter{ + dataConverter: dataConverter, + } +} + +var contextAwareDataConverter = newContextAwareDataConverter(converter.GetDefaultDataConverter()) + +func TestContextAwareDataConverter(t *testing.T) { + t.Parallel() + t.Run("default", func(t *testing.T) { + t.Parallel() + payload, _ := contextAwareDataConverter.ToPayload("test") + result := contextAwareDataConverter.ToString(payload) + + require.Equal(t, `"test"`, result) + }) + t.Run("implements ContextAware", func(t *testing.T) { + t.Parallel() + _, ok := contextAwareDataConverter.(ContextAware) + require.True(t, ok) + }) + t.Run("with activity context", func(t *testing.T) { + t.Parallel() + ctx := context.Background() + ctx = context.WithValue(ctx, prefixContextKey, "testing") + + dc := WithContext(ctx, contextAwareDataConverter) + + payload, _ := dc.ToPayload("test") + result := dc.ToString(payload) + + require.Equal(t, `testing: "test"`, result) + }) + t.Run("with workflow context", func(t *testing.T) { + t.Parallel() + ctx := Background() + ctx = WithValue(ctx, prefixContextKey, "testing") + + dc := WithWorkflowContext(ctx, contextAwareDataConverter) + + payload, _ := dc.ToPayload("test") + result := dc.ToString(payload) + + require.Equal(t, `testing: "test"`, result) + }) +} diff --git a/internal/workflow.go b/internal/workflow.go index 5fa213c8f..7cf5c3cd8 100644 --- a/internal/workflow.go +++ b/internal/workflow.go @@ -679,7 +679,7 @@ func (wc *workflowEnvironmentInterceptor) ExecuteChildWorkflow(ctx Context, chil executionFuture: executionFuture.(*futureImpl), } workflowOptionsFromCtx := getWorkflowEnvOptions(ctx) - dc := workflowOptionsFromCtx.DataConverter + dc := WithWorkflowContext(ctx, workflowOptionsFromCtx.DataConverter) env := getWorkflowEnvironment(ctx) wfType, input, err := getValidatedWorkflowFunction(childWorkflowType, args, dc, env.GetRegistry()) if err != nil { diff --git a/workflow/context.go b/workflow/context.go index d8599a597..d7b20dacf 100644 --- a/workflow/context.go +++ b/workflow/context.go @@ -36,6 +36,11 @@ import ( // Context's methods may be called by multiple goroutines simultaneously. type Context = internal.Context +// ContextAware is an optional interface that can be implemented alongside DataConverter. +// This interface allows Temporal to pass Workflow/Activity contexts to the DataConverter +// so that it may tailor it's behaviour. +type ContextAware = internal.ContextAware + // ErrCanceled is the error returned by Context.Err when the context is canceled. var ErrCanceled = internal.ErrCanceled