Skip to content

Commit

Permalink
Support for context-aware data converters. (#393)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rob Holland authored Apr 1, 2021
1 parent 91db6e6 commit d46c109
Show file tree
Hide file tree
Showing 9 changed files with 236 additions and 36 deletions.
2 changes: 1 addition & 1 deletion converter/composite_data_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type (
// Order is important here because during serialization DataConverter will try PayloadsConverters in
// that order until PayloadConverter returns non nil payload.
// Last PayloadConverter should always serialize the value (JSONPayloadConverter is good candidate for it).
func NewCompositeDataConverter(payloadConverters ...PayloadConverter) *CompositeDataConverter {
func NewCompositeDataConverter(payloadConverters ...PayloadConverter) DataConverter {
dc := &CompositeDataConverter{
payloadConverters: make(map[string]PayloadConverter, len(payloadConverters)),
orderedEncodings: make([]string, len(payloadConverters)),
Expand Down
2 changes: 1 addition & 1 deletion converter/default_data_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@ var (
)

// GetDefaultDataConverter returns default data converter used by Temporal worker.
func GetDefaultDataConverter() *CompositeDataConverter {
func GetDefaultDataConverter() DataConverter {
return defaultDataConverter
}
83 changes: 58 additions & 25 deletions internal/headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,41 +27,53 @@ package internal
import (
"context"

"go.temporal.io/sdk/converter"

commonpb "go.temporal.io/api/common/v1"
)

// HeaderWriter is an interface to write information to temporal headers
type HeaderWriter interface {
Set(string, *commonpb.Payload)
}
type (
HeaderWriter interface {
Set(string, *commonpb.Payload)
}

// HeaderReader is an interface to read information from temporal headers
type HeaderReader interface {
Get(string) (*commonpb.Payload, bool)
ForEachKey(handler func(string, *commonpb.Payload) error) error
}
// HeaderReader is an interface to read information from temporal headers
HeaderReader interface {
Get(string) (*commonpb.Payload, bool)
ForEachKey(handler func(string, *commonpb.Payload) error) error
}

// ContextPropagator is an interface that determines what information from
// context to pass along
type ContextPropagator interface {
// Inject injects information from a Go Context into headers
Inject(context.Context, HeaderWriter) error
// ContextPropagator is an interface that determines what information from
// context to pass along
ContextPropagator interface {
// Inject injects information from a Go Context into headers
Inject(context.Context, HeaderWriter) error

// Extract extracts context information from headers and returns a context
// object
Extract(context.Context, HeaderReader) (context.Context, error)
// Extract extracts context information from headers and returns a context
// object
Extract(context.Context, HeaderReader) (context.Context, error)

// InjectFromWorkflow injects information from workflow context into headers
InjectFromWorkflow(Context, HeaderWriter) error
// InjectFromWorkflow injects information from workflow context into headers
InjectFromWorkflow(Context, HeaderWriter) error

// ExtractToWorkflow extracts context information from headers and returns
// a workflow context
ExtractToWorkflow(Context, HeaderReader) (Context, error)
}
// ExtractToWorkflow extracts context information from headers and returns
// a workflow context
ExtractToWorkflow(Context, HeaderReader) (Context, error)
}

type headerReader struct {
header *commonpb.Header
}
// ContextAware is an optional interface that can be implemented alongside DataConverter.
// This interface allows Temporal to pass Workflow/Activity contexts to the DataConverter
// so that it may tailor it's behaviour.
ContextAware interface {
WithWorkflowContext(ctx Context) converter.DataConverter
WithContext(ctx context.Context) converter.DataConverter
}

headerReader struct {
header *commonpb.Header
}
)

func (hr *headerReader) ForEachKey(handler func(string, *commonpb.Payload) error) error {
if hr.header == nil {
Expand Down Expand Up @@ -106,3 +118,24 @@ func NewHeaderWriter(header *commonpb.Header) HeaderWriter {
}
return &headerWriter{header: header}
}

// WithWorkflowContext returns a new DataConverter tailored to the passed Workflow context if
// the DataConverter implements the ContextAware interface. Otherwise the DataConverter is returned
// as-is.
func WithWorkflowContext(ctx Context, dc converter.DataConverter) converter.DataConverter {
if d, ok := dc.(ContextAware); ok {
return d.WithWorkflowContext(ctx)
}
return dc
}

// WithContext returns a new DataConverter tailored to the passed Workflow/Activity context if
// the DataConverter implements the ContextAware interface. Otherwise the DataConverter is returned
// as-is. This is generally used for Activity context but can be context for a Workflow if we're
// not yet executing the workflow so do not have a workflow.Context.
func WithContext(ctx context.Context, dc converter.DataConverter) converter.DataConverter {
if d, ok := dc.(ContextAware); ok {
return d.WithContext(ctx)
}
return dc
}
12 changes: 8 additions & 4 deletions internal/internal_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ type workflowExecutor struct {

func (we *workflowExecutor) Execute(ctx Context, input *commonpb.Payloads) (*commonpb.Payloads, error) {
var args []interface{}
dataConverter := getWorkflowEnvOptions(ctx).DataConverter
dataConverter := WithWorkflowContext(ctx, getWorkflowEnvOptions(ctx).DataConverter)
fnType := reflect.TypeOf(we.fn)

decoded, err := decodeArgsToValues(dataConverter, fnType, input)
Expand Down Expand Up @@ -832,11 +832,15 @@ func (ae *activityExecutor) executeWithActualArgsWithoutParseResult(ctx context.
}

func getDataConverterFromActivityCtx(ctx context.Context) converter.DataConverter {
var dataConverter converter.DataConverter

env := getActivityEnvironmentFromCtx(ctx)
if env == nil || env.dataConverter == nil {
return converter.GetDefaultDataConverter()
if env != nil && env.dataConverter != nil {
dataConverter = env.dataConverter
} else {
dataConverter = converter.GetDefaultDataConverter()
}
return env.dataConverter
return WithContext(ctx, dataConverter)
}

func getNamespaceFromActivityCtx(ctx context.Context) string {
Expand Down
11 changes: 8 additions & 3 deletions internal/internal_workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -1238,10 +1238,15 @@ func setWorkflowEnvOptionsIfNotExist(ctx Context) Context {

func getDataConverterFromWorkflowContext(ctx Context) converter.DataConverter {
options := getWorkflowEnvOptions(ctx)
if options == nil || options.DataConverter == nil {
return converter.GetDefaultDataConverter()
var dataConverter converter.DataConverter

if options != nil && options.DataConverter != nil {
dataConverter = options.DataConverter
} else {
dataConverter = converter.GetDefaultDataConverter()
}
return options.DataConverter

return WithWorkflowContext(ctx, dataConverter)
}

func getRegistryFromWorkflowContext(ctx Context) *registry {
Expand Down
3 changes: 2 additions & 1 deletion internal/internal_workflow_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,9 @@ func (wc *WorkflowClient) StartWorkflow(
runTimeout := options.WorkflowRunTimeout
workflowTaskTimeout := options.WorkflowTaskTimeout

dataConverter := WithContext(ctx, wc.dataConverter)
// Validate type and its arguments.
workflowType, input, err := getValidatedWorkflowFunction(workflowFunc, args, wc.dataConverter, wc.registry)
workflowType, input, err := getValidatedWorkflowFunction(workflowFunc, args, dataConverter, wc.registry)
if err != nil {
return nil, err
}
Expand Down
152 changes: 152 additions & 0 deletions internal/stateful_data_converter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package internal

import (
"context"
"go.temporal.io/sdk/converter"
"testing"

"github.com/stretchr/testify/require"
commonpb "go.temporal.io/api/common/v1"
)

type ContextAwareDataConverter struct {
dataConverter converter.DataConverter
prefix string
}

type contextKeyType int

const (
prefixContextKey contextKeyType = iota
)

func (dc *ContextAwareDataConverter) ToPayload(value interface{}) (*commonpb.Payload, error) {
return dc.dataConverter.ToPayload(value)
}

func (dc *ContextAwareDataConverter) ToPayloads(values ...interface{}) (*commonpb.Payloads, error) {
return dc.dataConverter.ToPayloads(values)
}

func (dc *ContextAwareDataConverter) FromPayload(payload *commonpb.Payload, valuePtr interface{}) error {
return dc.dataConverter.FromPayload(payload, valuePtr)
}

func (dc *ContextAwareDataConverter) FromPayloads(payloads *commonpb.Payloads, valuePtrs ...interface{}) error {
return dc.dataConverter.FromPayloads(payloads, valuePtrs...)
}

func (dc *ContextAwareDataConverter) ToString(payload *commonpb.Payload) string {
if dc.prefix != "" {
return dc.prefix + ": " + dc.dataConverter.ToString(payload)
}

return dc.dataConverter.ToString(payload)
}

func (dc *ContextAwareDataConverter) ToStrings(payloads *commonpb.Payloads) []string {
var result []string
for _, payload := range payloads.GetPayloads() {
result = append(result, dc.ToString(payload))
}

return result
}

func (dc *ContextAwareDataConverter) WithContext(ctx context.Context) converter.DataConverter {
v := ctx.Value(prefixContextKey)
prefix, ok := v.(string)
if !ok {
return dc
}

return &ContextAwareDataConverter{
dataConverter: dc.dataConverter,
prefix: prefix,
}
}

func (dc *ContextAwareDataConverter) WithWorkflowContext(ctx Context) converter.DataConverter {
v := ctx.Value(prefixContextKey)
prefix, ok := v.(string)
if !ok {
return dc
}

return &ContextAwareDataConverter{
dataConverter: dc.dataConverter,
prefix: prefix,
}
}

func newContextAwareDataConverter(dataConverter converter.DataConverter) converter.DataConverter {
return &ContextAwareDataConverter{
dataConverter: dataConverter,
}
}

var contextAwareDataConverter = newContextAwareDataConverter(converter.GetDefaultDataConverter())

func TestContextAwareDataConverter(t *testing.T) {
t.Parallel()
t.Run("default", func(t *testing.T) {
t.Parallel()
payload, _ := contextAwareDataConverter.ToPayload("test")
result := contextAwareDataConverter.ToString(payload)

require.Equal(t, `"test"`, result)
})
t.Run("implements ContextAware", func(t *testing.T) {
t.Parallel()
_, ok := contextAwareDataConverter.(ContextAware)
require.True(t, ok)
})
t.Run("with activity context", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctx = context.WithValue(ctx, prefixContextKey, "testing")

dc := WithContext(ctx, contextAwareDataConverter)

payload, _ := dc.ToPayload("test")
result := dc.ToString(payload)

require.Equal(t, `testing: "test"`, result)
})
t.Run("with workflow context", func(t *testing.T) {
t.Parallel()
ctx := Background()
ctx = WithValue(ctx, prefixContextKey, "testing")

dc := WithWorkflowContext(ctx, contextAwareDataConverter)

payload, _ := dc.ToPayload("test")
result := dc.ToString(payload)

require.Equal(t, `testing: "test"`, result)
})
}
2 changes: 1 addition & 1 deletion internal/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ func (wc *workflowEnvironmentInterceptor) ExecuteChildWorkflow(ctx Context, chil
executionFuture: executionFuture.(*futureImpl),
}
workflowOptionsFromCtx := getWorkflowEnvOptions(ctx)
dc := workflowOptionsFromCtx.DataConverter
dc := WithWorkflowContext(ctx, workflowOptionsFromCtx.DataConverter)
env := getWorkflowEnvironment(ctx)
wfType, input, err := getValidatedWorkflowFunction(childWorkflowType, args, dc, env.GetRegistry())
if err != nil {
Expand Down
5 changes: 5 additions & 0 deletions workflow/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ import (
// Context's methods may be called by multiple goroutines simultaneously.
type Context = internal.Context

// ContextAware is an optional interface that can be implemented alongside DataConverter.
// This interface allows Temporal to pass Workflow/Activity contexts to the DataConverter
// so that it may tailor it's behaviour.
type ContextAware = internal.ContextAware

// ErrCanceled is the error returned by Context.Err when the context is canceled.
var ErrCanceled = internal.ErrCanceled

Expand Down

0 comments on commit d46c109

Please sign in to comment.