Skip to content

Commit

Permalink
Prevent modifying workflow state in read only contexts (temporalio#1159)
Browse files Browse the repository at this point in the history
Prevent modifying workflow state in read only contexts
  • Loading branch information
Quinn-With-Two-Ns authored Jul 17, 2023
1 parent e1d76b7 commit 4714f38
Show file tree
Hide file tree
Showing 8 changed files with 279 additions and 11 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/docker/dynamic-config-custom.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,8 @@ system.forceSearchAttributesCacheRefreshOnRead:
- value: true # Dev setup only. Please don't turn this on in production.
constraints: {}
system.enableActivityEagerExecution:
- value: true
frontend.enableUpdateWorkflowExecution:
- value: true
frontend.enableUpdateWorkflowExecutionAsyncAccepted:
- value: true
11 changes: 8 additions & 3 deletions internal/internal_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ type (
Complete(success interface{}, err error)
}

// UpdateScheduluer allows an update state machine to spawn coroutines and
// UpdateScheduler allows an update state machine to spawn coroutines and
// yield itself as necessary.
UpdateScheduler interface {
// Spawn starts a new named coroutine, executing the given function f.
Expand Down Expand Up @@ -231,7 +231,7 @@ func (up *updateProtocol) checkAcceptedEvent(e *historypb.HistoryEvent) bool {
attrs.AcceptedRequest != nil
}

// defaultHandler receives the initial invocation of an upate during WFT
// defaultHandler receives the initial invocation of an update during WFT
// processing. The implementation will verify that an updateHandler exists for
// the supplied name (rejecting the update otherwise) and use the provided spawn
// function to create a new coroutine that will execute in the workflow context.
Expand Down Expand Up @@ -289,7 +289,12 @@ func defaultUpdateHandler(
if !IsReplaying(ctx) {
// we don't execute update validation during replay so that
// validation routines can change across versions
if err := envInterceptor.inboundInterceptor.ValidateUpdate(ctx, &input); err != nil {
err = func() error {
defer getState(ctx).dispatcher.setIsReadOnly(false)
getState(ctx).dispatcher.setIsReadOnly(true)
return envInterceptor.inboundInterceptor.ValidateUpdate(ctx, &input)
}()
if err != nil {
callbacks.Reject(err)
return
}
Expand Down
19 changes: 17 additions & 2 deletions internal/internal_update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ func TestUpdateHandlerPanicHandling(t *testing.T) {
}
interceptor, ctx, err := newWorkflowContext(env, nil)
require.NoError(t, err)
dispatcher, ctx := newDispatcher(
ctx,
interceptor,
func(ctx Context) {})
dispatcher.executing = true

panicFunc := func() error { panic("intentional") }
mustSetUpdateHandler(t, ctx, t.Name(), panicFunc, UpdateHandlerOptions{Validator: panicFunc})
Expand Down Expand Up @@ -176,8 +181,13 @@ func TestDefaultUpdateHandler(t *testing.T) {
TaskQueueName: "taskqueue:" + t.Name(),
},
}
_, ctx, err := newWorkflowContext(env, nil)
interceptor, ctx, err := newWorkflowContext(env, nil)
require.NoError(t, err)
dispatcher, ctx := newDispatcher(
ctx,
interceptor,
func(ctx Context) {})
dispatcher.executing = true

hdr := &commonpb.Header{Fields: map[string]*commonpb.Payload{}}
argStr := t.Name()
Expand Down Expand Up @@ -288,8 +298,13 @@ func TestDefaultUpdateHandler(t *testing.T) {
// don't reuse the context that has all the other update handlers
// registered because the code under test will think the handler
// registration at workflow start time has already occurred
_, ctx, err := newWorkflowContext(env, nil)
interceptor, ctx, err := newWorkflowContext(env, nil)
require.NoError(t, err)
dispatcher, ctx := newDispatcher(
ctx,
interceptor,
func(ctx Context) {})
dispatcher.executing = true

updateFunc := func(ctx Context, s string) (string, error) { return s + " success!", nil }
var (
Expand Down
35 changes: 31 additions & 4 deletions internal/internal_workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import (
const (
defaultSignalChannelSize = 100000 // really large buffering size(100K)

panicIllegalAccessCoroutinueState = "getState: illegal access from outside of workflow context"
panicIllegalAccessCoroutineState = "getState: illegal access from outside of workflow context"
)

type (
Expand Down Expand Up @@ -170,6 +170,7 @@ type (
closed bool
interceptor WorkflowOutboundInterceptor
deadlockDetector *deadlockDetector
readOnly bool
}

// WorkflowOptions options passed to the workflow function
Expand Down Expand Up @@ -315,6 +316,7 @@ func getWorkflowOutboundInterceptor(ctx Context) WorkflowOutboundInterceptor {
}

func (f *futureImpl) Get(ctx Context, valuePtr interface{}) error {
assertNotInReadOnlyState(ctx)
more := f.channel.Receive(ctx, nil)
if more {
panic("not closed")
Expand Down Expand Up @@ -436,6 +438,7 @@ func (f *childWorkflowFutureImpl) GetChildWorkflowExecution() Future {
}

func (f *childWorkflowFutureImpl) SignalChildWorkflow(ctx Context, signalName string, data interface{}) Future {
assertNotInReadOnlyState(ctx)
var childExec WorkflowExecution
if err := f.GetChildWorkflowExecution().Get(ctx, &childExec); err != nil {
return f.GetChildWorkflowExecution()
Expand Down Expand Up @@ -646,11 +649,20 @@ func getState(ctx Context) *coroutineState {
}
state := s.(*coroutineState)
if !state.dispatcher.IsExecuting() {
panic(panicIllegalAccessCoroutinueState)
panic(panicIllegalAccessCoroutineState)
}
return state
}

func assertNotInReadOnlyState(ctx Context) {
state := getState(ctx)
// use the dispatcher state instead of the coroutine state because contexts can be
// shared
if state.dispatcher.getIsReadOnly() {
panic(panicIllegalAccessCoroutineState)
}
}

func getStateIfRunning(ctx Context) *coroutineState {
if ctx == nil {
return nil
Expand All @@ -675,6 +687,7 @@ func (c *channelImpl) CanSendWithoutBlocking() bool {
}

func (c *channelImpl) Receive(ctx Context, valuePtr interface{}) (more bool) {
assertNotInReadOnlyState(ctx)
state := getState(ctx)
hasResult := false
var result interface{}
Expand Down Expand Up @@ -1103,6 +1116,18 @@ func (d *dispatcherImpl) IsExecuting() bool {
return d.executing
}

func (d *dispatcherImpl) getIsReadOnly() bool {
d.mutex.Lock()
defer d.mutex.Unlock()
return d.readOnly
}

func (d *dispatcherImpl) setIsReadOnly(readOnly bool) {
d.mutex.Lock()
defer d.mutex.Unlock()
d.readOnly = readOnly
}

func (d *dispatcherImpl) Close() {
d.mutex.Lock()
if d.closed {
Expand Down Expand Up @@ -1170,6 +1195,7 @@ func (s *selectorImpl) HasPending() bool {
}

func (s *selectorImpl) Select(ctx Context) {
assertNotInReadOnlyState(ctx)
state := getState(ctx)
var readyBranch func()
var cleanups []func()
Expand Down Expand Up @@ -1521,7 +1547,7 @@ func (h *queryHandler) execute(input []interface{}) (result interface{}, err err
if p := recover(); p != nil {
result = nil
st := getStackTraceRaw("query handler [panic]:", 7, 0)
if p == panicIllegalAccessCoroutinueState {
if p == panicIllegalAccessCoroutineState {
// query handler code try to access workflow functions outside of workflow context, make error message
// more descriptive and clear.
p = "query handler must not use temporal context to do things like workflow.NewChannel(), " +
Expand Down Expand Up @@ -1567,11 +1593,12 @@ func (wg *waitGroupImpl) Done() {
wg.Add(-1)
}

// Wait blocks and waits for specified number of couritines to
// Wait blocks and waits for specified number of coroutines to
// finish executing and then unblocks once the counter has reached 0.
//
// param ctx Context -> workflow context
func (wg *waitGroupImpl) Wait(ctx Context) {
assertNotInReadOnlyState(ctx)
if wg.n <= 0 {
return
}
Expand Down
101 changes: 101 additions & 0 deletions internal/internal_workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1348,6 +1348,107 @@ func (s *WorkflowUnitTest) Test_WaitGroupWorkflowTest() {
s.Equal(n, total)
}

func (s *WorkflowUnitTest) Test_MutatingFunctionsInSideEffect() {
env := s.NewTestWorkflowEnvironment()

wf := func(ctx Context) error {
SideEffect(ctx, func(ctx Context) interface{} {
_ = Sleep(ctx, time.Minute)
return nil
})
return nil
}
env.RegisterWorkflow(wf)
env.ExecuteWorkflow(wf)
s.True(env.IsWorkflowCompleted())
s.Error(env.GetWorkflowError())
}

func (s *WorkflowUnitTest) Test_MutatingFunctionsInMutableSideEffect() {
env := s.NewTestWorkflowEnvironment()

wf := func(ctx Context) error {
MutableSideEffect(ctx, "test-side-effect", func(ctx Context) interface{} {
_ = Sleep(ctx, time.Minute)
return nil
}, func(a, b interface{}) bool { return false })
return nil
}
env.RegisterWorkflow(wf)
env.ExecuteWorkflow(wf)
s.True(env.IsWorkflowCompleted())
s.Error(env.GetWorkflowError())
}

func (s *WorkflowUnitTest) Test_MutatingFunctionsInQueries() {
env := s.NewTestWorkflowEnvironment()

wf := func(ctx Context) error {
currentState := "fail"
_ = SetQueryHandler(ctx, queryType, func() (string, error) {
_ = Sleep(ctx, time.Minute)
return currentState, nil
})
_ = Sleep(ctx, time.Minute)
return nil
}
env.RegisterWorkflow(wf)
env.RegisterDelayedCallback(func() {
_, err := env.QueryWorkflow(queryType, "test")
s.Error(err)
}, time.Second)
env.ExecuteWorkflow(wf)
s.True(env.IsWorkflowCompleted())
s.NoError(env.GetWorkflowError())
}

type updateCallback struct {
accept func()
reject func(error)
complete func(interface{}, error)
}

func (uc *updateCallback) Accept() {
uc.accept()
}

func (uc *updateCallback) Reject(err error) {
uc.reject(err)
}

func (uc *updateCallback) Complete(success interface{}, err error) {
uc.complete(success, err)
}

func (s *WorkflowUnitTest) Test_MutatingFunctionsInUpdateValidator() {
env := s.NewTestWorkflowEnvironment()

wf := func(ctx Context) error {
currentState := "fail"
_ = SetUpdateHandler(ctx, updateType, func(ctx Context) (string, error) {
_ = Sleep(ctx, time.Minute)
return currentState, nil
}, UpdateHandlerOptions{
Validator: func(ctx Context) error {
return Sleep(ctx, time.Minute)
},
})
_ = Sleep(ctx, time.Minute)
return nil
}
env.RegisterWorkflow(wf)
env.RegisterDelayedCallback(func() {
env.UpdateWorkflow(updateType, &updateCallback{
reject: func(err error) {
s.Error(err)
},
})
}, time.Second)
env.ExecuteWorkflow(wf)
s.True(env.IsWorkflowCompleted())
s.NoError(env.GetWorkflowError())
}

func (s *WorkflowUnitTest) Test_StaleGoroutinesAreShutDown() {
env := s.NewTestWorkflowEnvironment()
deferred := make(chan struct{})
Expand Down
Loading

0 comments on commit 4714f38

Please sign in to comment.