Skip to content

Commit

Permalink
Broke channel into SendChannel and ReceiveChannel interfaces (#100)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfateev committed Apr 14, 2020
1 parent 3ca1c77 commit 6d109bf
Show file tree
Hide file tree
Showing 12 changed files with 109 additions and 62 deletions.
4 changes: 2 additions & 2 deletions internal/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,10 @@ func propagateCancel(parent Context, child canceler) {
} else {
go func() {
s := NewSelector(parent)
s.AddReceive(parent.Done(), func(c Channel, more bool) {
s.AddReceive(parent.Done(), func(c ReceiveChannel, more bool) {
child.cancel(false, parent.Err())
})
s.AddReceive(child.Done(), func(c Channel, more bool) {})
s.AddReceive(child.Done(), func(c ReceiveChannel, more bool) {})
s.Select(parent)
}()
}
Expand Down
4 changes: 2 additions & 2 deletions internal/interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type WorkflowInterceptor interface {
RequestCancelExternalWorkflow(ctx Context, workflowID, runID string) Future
SignalExternalWorkflow(ctx Context, workflowID, runID, signalName string, arg interface{}) Future
UpsertSearchAttributes(ctx Context, attributes map[string]interface{}) error
GetSignalChannel(ctx Context, signalName string) Channel
GetSignalChannel(ctx Context, signalName string) ReceiveChannel
SideEffect(ctx Context, f func(ctx Context) interface{}) Value
MutableSideEffect(ctx Context, id string, f func(ctx Context) interface{}, equals func(a, b interface{}) bool) Value
GetVersion(ctx Context, changeID string, minSupported, maxSupported Version) Version
Expand Down Expand Up @@ -146,7 +146,7 @@ func (t *WorkflowInterceptorBase) UpsertSearchAttributes(ctx Context, attributes
}

// GetSignalChannel forwards to t.Next
func (t *WorkflowInterceptorBase) GetSignalChannel(ctx Context, signalName string) Channel {
func (t *WorkflowInterceptorBase) GetSignalChannel(ctx Context, signalName string) ReceiveChannel {
return t.Next.GetSignalChannel(ctx, signalName)
}

Expand Down
46 changes: 38 additions & 8 deletions internal/internal_coroutines_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func requireNoExecuteErr(t *testing.T, err error) {
func TestDispatcher(t *testing.T) {
value := "foo"
d, _ := newDispatcher(createRootTestContext(), func(ctx Context) { value = "bar" })
defer d.Close()
require.Equal(t, "foo", value)
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone())
Expand All @@ -69,6 +70,7 @@ func TestNonBlockingChildren(t *testing.T) {
}
history = append(history, "root")
})
defer d.Close()
require.EqualValues(t, 0, len(history))
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone())
Expand All @@ -92,6 +94,7 @@ func TestNonbufferedChannel(t *testing.T) {
history = append(history, "root-after-channel-put")

})
defer d.Close()
require.EqualValues(t, 0, len(history))
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone())
Expand Down Expand Up @@ -138,6 +141,7 @@ func TestNonbufferedChannelBlockedReceive(t *testing.T) {
history = append(history, "root-after-channel-put")

})
defer d.Close()
require.EqualValues(t, 0, len(history))
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
c2.SendAsync("value21")
Expand Down Expand Up @@ -169,6 +173,7 @@ func TestBufferedChannelPut(t *testing.T) {
c1.Send(ctx, "value2")
history = append(history, "root-after-channel-put2")
})
defer d.Close()
require.EqualValues(t, 0, len(history))
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone())
Expand Down Expand Up @@ -219,6 +224,7 @@ func TestBufferedChannelGet(t *testing.T) {
c1.Send(ctx, "value2")
history = append(history, "root-after-channel-put2")
})
defer d.Close()
require.EqualValues(t, 0, len(history))
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone(), strings.Join(history, "\n")+"\n\n"+d.StackTrace())
Expand Down Expand Up @@ -246,13 +252,13 @@ func TestNotBlockingSelect(t *testing.T) {
c2 := NewBufferedChannel(ctx, 1)
s := NewSelector(ctx)
s.
AddReceive(c1, func(c Channel, more bool) {
AddReceive(c1, func(c ReceiveChannel, more bool) {
require.True(t, more)
var v string
c.Receive(ctx, &v)
history = append(history, fmt.Sprintf("c1-%v", v))
}).
AddReceive(c2, func(c Channel, more bool) {
AddReceive(c2, func(c ReceiveChannel, more bool) {
require.True(t, more)
var v string
c.Receive(ctx, &v)
Expand All @@ -265,6 +271,7 @@ func TestNotBlockingSelect(t *testing.T) {
s.Select(ctx)
s.Select(ctx)
})
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone())

Expand Down Expand Up @@ -295,13 +302,13 @@ func TestBlockingSelect(t *testing.T) {

s := NewSelector(ctx)
s.
AddReceive(c1, func(c Channel, more bool) {
AddReceive(c1, func(c ReceiveChannel, more bool) {
require.True(t, more)
var v string
c.Receive(ctx, &v)
history = append(history, fmt.Sprintf("c1-%v", v))
}).
AddReceive(c2, func(c Channel, more bool) {
AddReceive(c2, func(c ReceiveChannel, more bool) {
var v string
c.Receive(ctx, &v)
history = append(history, fmt.Sprintf("c2-%v", v))
Expand All @@ -312,6 +319,7 @@ func TestBlockingSelect(t *testing.T) {
s.Select(ctx)
history = append(history, "done")
})
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone(), strings.Join(history, "\n"))

Expand All @@ -336,7 +344,7 @@ func TestBlockingSelectAsyncSend(t *testing.T) {
c1 := NewChannel(ctx)
s := NewSelector(ctx)
s.
AddReceive(c1, func(c Channel, more bool) {
AddReceive(c1, func(c ReceiveChannel, more bool) {
require.True(t, more)
var v int
c.Receive(ctx, &v)
Expand All @@ -353,6 +361,7 @@ func TestBlockingSelectAsyncSend(t *testing.T) {
}
history = append(history, "done")
})
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone(), strings.Join(history, "\n"))

Expand Down Expand Up @@ -380,7 +389,7 @@ func TestSelectOnClosedChannel(t *testing.T) {

selector := NewNamedSelector(ctx, "waiting for channel")

selector.AddReceive(c, func(f Channel, more bool) {
selector.AddReceive(c, func(f ReceiveChannel, more bool) {
var n int

if !more {
Expand All @@ -401,6 +410,7 @@ func TestSelectOnClosedChannel(t *testing.T) {
selector.Select(ctx)
selector.Select(ctx)
})
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone(), strings.Join(history, "\n"))

Expand All @@ -420,13 +430,13 @@ func TestBlockingSelectAsyncSend2(t *testing.T) {
c2 := NewBufferedChannel(ctx, 100)
s := NewSelector(ctx)
s.
AddReceive(c1, func(c Channel, more bool) {
AddReceive(c1, func(c ReceiveChannel, more bool) {
require.True(t, more)
var v string
c.Receive(ctx, &v)
history = append(history, fmt.Sprintf("c1-%v", v))
}).
AddReceive(c2, func(c Channel, more bool) {
AddReceive(c2, func(c ReceiveChannel, more bool) {
require.True(t, more)
var v string
c.Receive(ctx, &v)
Expand All @@ -443,6 +453,7 @@ func TestBlockingSelectAsyncSend2(t *testing.T) {
s.Select(ctx)
history = append(history, "done")
})
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone(), strings.Join(history, "\n"))

Expand Down Expand Up @@ -483,6 +494,7 @@ func TestSendSelect(t *testing.T) {
s.Select(ctx)
history = append(history, "done")
})
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone())

Expand Down Expand Up @@ -525,6 +537,7 @@ func TestSendSelectWithAsyncReceive(t *testing.T) {
s.Select(ctx)
history = append(history, "done")
})
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone(), strings.Join(history, "\n"))

Expand Down Expand Up @@ -570,6 +583,7 @@ func TestChannelClose(t *testing.T) {
history = append(history, "done")

})
defer d.Close()
require.EqualValues(t, 0, len(history))
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone(), d.StackTrace())
Expand Down Expand Up @@ -599,6 +613,7 @@ func TestSendClosedChannel(t *testing.T) {
})
c.Send(ctx, "baz")
})
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone())
}
Expand All @@ -613,6 +628,7 @@ func TestBlockedSendClosedChannel(t *testing.T) {
c.Close()
c.Send(ctx, "baz")
})
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone())
}
Expand All @@ -627,6 +643,7 @@ func TestAsyncSendClosedChannel(t *testing.T) {
c.Close()
_ = c.SendAsync("baz")
})
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone())
}
Expand Down Expand Up @@ -683,6 +700,7 @@ func TestPanic(t *testing.T) {
history = append(history, "root")
c.Receive(ctx, nil) // blocked forever
})
defer d.Close()
require.EqualValues(t, 0, len(history))
err := d.ExecuteUntilAllBlocked()
require.Error(t, err)
Expand All @@ -699,6 +717,7 @@ func TestAwait(t *testing.T) {
d, _ := newDispatcher(createRootTestContext(), func(ctx Context) {
_ = Await(ctx, func() bool { return flag })
})
defer d.Close()
err := d.ExecuteUntilAllBlocked()
require.NoError(t, err)
require.False(t, d.IsDone())
Expand All @@ -718,6 +737,7 @@ func TestAwaitCancellation(t *testing.T) {
d, _ := newDispatcher(ctx, func(ctx Context) {
awaitError = Await(ctx, func() bool { return false })
})
defer d.Close()
err := d.ExecuteUntilAllBlocked()
require.NoError(t, err)
require.False(t, d.IsDone())
Expand All @@ -737,6 +757,7 @@ func TestAwaitWithTimeoutNoTimeout(t *testing.T) {
d, _ := newDispatcher(createRootTestContext(), func(ctx Context) {
awaitOk, awaitWithTimeoutError = AwaitWithTimeout(ctx, time.Hour, func() bool { return flag })
})
defer d.Close()
err := d.ExecuteUntilAllBlocked()
require.NoError(t, err)
require.False(t, d.IsDone())
Expand All @@ -760,6 +781,7 @@ func TestAwaitWithTimeoutCancellation(t *testing.T) {
d, _ := newDispatcher(ctx, func(ctx Context) {
awaitOk, awaitWithTimeoutError = AwaitWithTimeout(ctx, time.Hour, func() bool { return false })
})
defer d.Close()
err := d.ExecuteUntilAllBlocked()
require.NoError(t, err)
require.False(t, d.IsDone())
Expand Down Expand Up @@ -797,6 +819,7 @@ func TestFutureSetValue(t *testing.T) {
history = append(history, "root-end")

})
defer d.Close()
require.EqualValues(t, 0, len(history))
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.False(t, d.IsDone(), fmt.Sprintf("%v", d.StackTrace()))
Expand Down Expand Up @@ -841,6 +864,7 @@ func TestFutureFail(t *testing.T) {
history = append(history, "root-end")

})
defer d.Close()
require.EqualValues(t, 0, len(history))
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.False(t, d.IsDone(), fmt.Sprintf("%v", d.StackTrace()))
Expand Down Expand Up @@ -898,6 +922,7 @@ func TestFutureSet(t *testing.T) {
})
history = append(history, "root-end")
})
defer d.Close()

require.EqualValues(t, 0, len(history))
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
Expand Down Expand Up @@ -971,6 +996,7 @@ func TestFutureChain(t *testing.T) {
history = append(history, "root-end")

})
defer d.Close()
require.EqualValues(t, 0, len(history))
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.False(t, d.IsDone(), fmt.Sprintf("%v", d.StackTrace()))
Expand Down Expand Up @@ -1036,6 +1062,7 @@ func TestSelectFuture(t *testing.T) {
s.Select(ctx)
history = append(history, "done")
})
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone())

Expand Down Expand Up @@ -1085,6 +1112,7 @@ func TestSelectDecodeFuture(t *testing.T) {
s.Select(ctx)
history = append(history, "done")
})
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone())

Expand Down Expand Up @@ -1143,6 +1171,7 @@ func TestDecodeFutureChain(t *testing.T) {
})
history = append(history, "root-end")
})
defer d.Close()
require.EqualValues(t, 0, len(history))
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
// set f1
Expand Down Expand Up @@ -1211,6 +1240,7 @@ func TestSelectFuture_WithBatchSets(t *testing.T) {
s.Select(ctx)
s.Select(ctx)
})
defer d.Close()
requireNoExecuteErr(t, d.ExecuteUntilAllBlocked())
require.True(t, d.IsDone())

Expand Down
1 change: 1 addition & 0 deletions internal/internal_task_handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,7 @@ func (t *TaskHandlersTestSuite) TestActivityExecutionWorkerStop() {
UserContext: ctx,
UserContextCancel: cancel,
WorkerStopChannel: workerStopCh,
Tracer: opentracing.NoopTracer{},
}
activityHandler := newActivityTaskHandler(mockService, wep, registry)
pats := &workflowservice.PollForActivityTaskResponse{
Expand Down
31 changes: 15 additions & 16 deletions internal/internal_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -971,30 +971,29 @@ func (aw *AggregatedWorker) Start() error {

if !isInterfaceNil(aw.workflowWorker) {
if len(aw.registry.getRegisteredWorkflowTypes()) == 0 {
aw.logger.Warn(
"Starting worker without any workflows. Workflows must be registered before start.",
)
}
if err := aw.workflowWorker.Start(); err != nil {
return err
aw.logger.Info("No workflows registered. Skipping workflow worker start")
} else {
if err := aw.workflowWorker.Start(); err != nil {
return err
}
}
}
if !isInterfaceNil(aw.activityWorker) {
if len(aw.registry.getRegisteredActivities()) == 0 {
aw.logger.Warn(
"Starting worker without any activities. Activities must be registered before start.",
)
}
if err := aw.activityWorker.Start(); err != nil {
// stop workflow worker.
if !isInterfaceNil(aw.workflowWorker) {
aw.workflowWorker.Stop()
aw.logger.Info("No activities registered. Skipping activity worker start")
} else {
if err := aw.activityWorker.Start(); err != nil {
// stop workflow worker.
if !isInterfaceNil(aw.workflowWorker) && len(aw.registry.getRegisteredWorkflowTypes()) > 0 {
aw.workflowWorker.Stop()
}
return err
}
return err
}
}

if !isInterfaceNil(aw.sessionWorker) {
if !isInterfaceNil(aw.sessionWorker) && len(aw.registry.getRegisteredActivities()) > 0 {
aw.logger.Info("Starting session worker")
if err := aw.sessionWorker.Start(); err != nil {
// stop workflow worker and activity worker.
if !isInterfaceNil(aw.workflowWorker) {
Expand Down
Loading

0 comments on commit 6d109bf

Please sign in to comment.