Skip to content

Commit

Permalink
Refactor NextCommandEvents (temporalio#1205)
Browse files Browse the repository at this point in the history
Refactor NextCommandEvents
  • Loading branch information
Quinn-With-Two-Ns authored Aug 23, 2023
1 parent 20c550a commit b9e5e24
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 65 deletions.
92 changes: 61 additions & 31 deletions internal/internal_task_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,14 @@ type (
historyMismatchError struct {
message string
}

preparedTask struct {
events []*historypb.HistoryEvent
markers []*historypb.HistoryEvent
flags []sdkFlag
msgs []*protocolpb.Message
binaryChecksum string
}
)

func newHistory(task *workflowTask, eventsHandler *workflowExecutionEventHandlerImpl) *history {
Expand Down Expand Up @@ -301,23 +309,40 @@ func isCommandEvent(eventType enumspb.EventType) bool {
}
}

// NextCommandEvents returns events that there processed as new by the next command.
// TODO(maxim): Refactor to return a struct instead of multiple parameters
func (eh *history) NextCommandEvents() (result []*historypb.HistoryEvent, markers []*historypb.HistoryEvent, binaryChecksum string, sdkFlags []sdkFlag, msgs []*protocolpb.Message, err error) {
// nextTask returns the next task to be processed.
func (eh *history) nextTask() (*preparedTask, error) {
if eh.next == nil {
eh.next, _, eh.nextFlags, _, err = eh.nextCommandEvents()
firstTask, err := eh.prepareTask()
if err != nil {
return result, markers, eh.binaryChecksum, sdkFlags, msgs, err
return nil, err
}
eh.next = firstTask.events
eh.nextFlags = firstTask.flags
}

result = eh.next
result := eh.next
checksum := eh.binaryChecksum
sdkFlags = eh.nextFlags
sdkFlags := eh.nextFlags

var markers []*historypb.HistoryEvent
var msgs []*protocolpb.Message
if len(result) > 0 {
eh.next, markers, eh.nextFlags, msgs, err = eh.nextCommandEvents()
nextTaskEvents, err := eh.prepareTask()
if err != nil {
return nil, err
}
eh.next = nextTaskEvents.events
eh.nextFlags = nextTaskEvents.flags
markers = nextTaskEvents.markers
msgs = nextTaskEvents.msgs
}
return result, markers, checksum, sdkFlags, msgs, err
return &preparedTask{
events: result,
markers: markers,
flags: sdkFlags,
msgs: msgs,
binaryChecksum: checksum,
}, nil
}

func (eh *history) hasMoreEvents() bool {
Expand Down Expand Up @@ -345,38 +370,38 @@ func (eh *history) verifyAllEventsProcessed() error {
return nil
}

func (eh *history) nextCommandEvents() (nextEvents []*historypb.HistoryEvent, markers []*historypb.HistoryEvent, sdkFlags []sdkFlag, msgs []*protocolpb.Message, err error) {
func (eh *history) prepareTask() (*preparedTask, error) {
if eh.currentIndex == len(eh.loadedEvents) && !eh.hasMoreEvents() {
if err := eh.verifyAllEventsProcessed(); err != nil {
return nil, nil, nil, nil, err
return nil, err
}
return []*historypb.HistoryEvent{}, []*historypb.HistoryEvent{}, []sdkFlag{}, []*protocolpb.Message{}, nil
return &preparedTask{}, nil
}

// Process events

var taskEvents preparedTask
OrderEvents:
for {
// load more history events if needed
for eh.currentIndex == len(eh.loadedEvents) {
if !eh.hasMoreEvents() {
if err = eh.verifyAllEventsProcessed(); err != nil {
return
if err := eh.verifyAllEventsProcessed(); err != nil {
return nil, err
}
break OrderEvents
}
if err = eh.loadMoreEvents(); err != nil {
return
if err := eh.loadMoreEvents(); err != nil {
return nil, err
}
}

event := eh.loadedEvents[eh.currentIndex]
eventID := event.GetEventId()
if eventID != eh.nextEventID {
err = fmt.Errorf(
err := fmt.Errorf(
"missing history events, expectedNextEventID=%v but receivedNextEventID=%v",
eh.nextEventID, eventID)
return
return nil, err
}

eh.nextEventID++
Expand All @@ -385,14 +410,14 @@ OrderEvents:
case enumspb.EVENT_TYPE_WORKFLOW_TASK_STARTED:
isFailed, binaryChecksum, newFlags, err1 := eh.IsNextWorkflowTaskFailed()
if err1 != nil {
err = err1
return
err := err1
return nil, err
}
if !isFailed {
eh.binaryChecksum = binaryChecksum
eh.currentIndex++
nextEvents = append(nextEvents, event)
sdkFlags = append(sdkFlags, newFlags...)
taskEvents.events = append(taskEvents.events, event)
taskEvents.flags = append(taskEvents.flags, newFlags...)
break OrderEvents
}
case enumspb.EVENT_TYPE_WORKFLOW_TASK_SCHEDULED,
Expand All @@ -401,11 +426,11 @@ OrderEvents:
// Skip
default:
if isPreloadMarkerEvent(event) {
markers = append(markers, event)
taskEvents.markers = append(taskEvents.markers, event)
} else if attrs := event.GetWorkflowExecutionUpdateAcceptedEventAttributes(); attrs != nil {
msgs = append(msgs, inferMessage(attrs))
taskEvents.msgs = append(taskEvents.msgs, inferMessage(attrs))
}
nextEvents = append(nextEvents, event)
taskEvents.events = append(taskEvents.events, event)
}
eh.currentIndex++
}
Expand All @@ -421,7 +446,7 @@ OrderEvents:

eh.currentIndex = 0

return nextEvents, markers, sdkFlags, msgs, nil
return &taskEvents, nil
}

func isPreloadMarkerEvent(event *historypb.HistoryEvent) bool {
Expand Down Expand Up @@ -920,7 +945,15 @@ func (w *workflowExecutionContextImpl) ProcessWorkflowTask(workflowTask *workflo

ProcessEvents:
for {
reorderedEvents, markers, binaryChecksum, flags, historyMessages, err := reorderedHistory.NextCommandEvents()
nextTask, err := reorderedHistory.nextTask()
if err != nil {
return nil, err
}
reorderedEvents := nextTask.events
markers := nextTask.markers
historyMessages := nextTask.msgs
flags := nextTask.flags
binaryChecksum := nextTask.binaryChecksum
// Check if we are replaying so we know if we should use the messages in the WFT or the history
isReplay := len(reorderedEvents) > 0 && reorderedHistory.IsReplayEvent(reorderedEvents[len(reorderedEvents)-1])
var msgs *eventMsgIndex
Expand All @@ -931,9 +964,6 @@ ProcessEvents:
taskMessages = []*protocolpb.Message{}
}

if err != nil {
return nil, err
}
eventHandler.sdkFlags.set(flags...)
if len(reorderedEvents) == 0 {
break ProcessEvents
Expand Down
68 changes: 34 additions & 34 deletions internal/internal_task_handlers_interfaces_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,13 @@ func (s *PollLayerInterfacesTestSuite) TestGetNextCommands() {

eh := newHistory(workflowTask, nil)

events, _, _, _, _, err := eh.NextCommandEvents()
nextTask, err := eh.nextTask()

s.NoError(err)
s.Equal(3, len(events))
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED, events[1].GetEventType())
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_STARTED, events[2].GetEventType())
s.Equal(int64(7), events[2].GetEventId())
s.Equal(3, len(nextTask.events))
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED, nextTask.events[1].GetEventType())
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_STARTED, nextTask.events[2].GetEventType())
s.Equal(int64(7), nextTask.events[2].GetEventId())
}

func (s *PollLayerInterfacesTestSuite) TestGetNextCommandsSdkFlags() {
Expand Down Expand Up @@ -232,26 +232,26 @@ func (s *PollLayerInterfacesTestSuite) TestGetNextCommandsSdkFlags() {

eh := newHistory(workflowTask, nil)

events, _, _, sdkFlags, _, err := eh.NextCommandEvents()
nextTask, err := eh.nextTask()

s.NoError(err)
s.Equal(2, len(events))
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_STARTED, events[1].GetEventType())
s.Equal(2, len(nextTask.events))
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_STARTED, nextTask.events[1].GetEventType())
// Verify the SDK flags are fetched at the correct point so they will be applied when the workflow
// function is run.
s.Equal(1, len(sdkFlags))
s.EqualValues(SDKFlagLimitChangeVersionSASize, sdkFlags[0])
s.Equal(1, len(nextTask.flags))
s.EqualValues(SDKFlagLimitChangeVersionSASize, nextTask.flags[0])

events, _, _, sdkFlags, _, err = eh.NextCommandEvents()
nextTask, err = eh.nextTask()

s.NoError(err)
s.Equal(4, len(events))
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_COMPLETED, events[0].GetEventType())
s.Equal(enumspb.EVENT_TYPE_MARKER_RECORDED, events[1].GetEventType())
s.Equal(enumspb.EVENT_TYPE_UPSERT_WORKFLOW_SEARCH_ATTRIBUTES, events[2].GetEventType())
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_STARTED, events[3].GetEventType())
s.Equal(4, len(nextTask.events))
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_COMPLETED, nextTask.events[0].GetEventType())
s.Equal(enumspb.EVENT_TYPE_MARKER_RECORDED, nextTask.events[1].GetEventType())
s.Equal(enumspb.EVENT_TYPE_UPSERT_WORKFLOW_SEARCH_ATTRIBUTES, nextTask.events[2].GetEventType())
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_STARTED, nextTask.events[3].GetEventType())

s.Equal(0, len(sdkFlags))
s.Equal(0, len(nextTask.flags))
}

func (s *PollLayerInterfacesTestSuite) TestMessageCommands() {
Expand Down Expand Up @@ -299,23 +299,23 @@ func (s *PollLayerInterfacesTestSuite) TestMessageCommands() {

eh := newHistory(workflowTask, nil)

events, _, _, _, msgs, err := eh.NextCommandEvents()
nextTask, err := eh.nextTask()
s.NoError(err)
s.Equal(2, len(events))
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED, events[0].GetEventType())
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_STARTED, events[1].GetEventType())
s.Equal(2, len(nextTask.events))
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED, nextTask.events[0].GetEventType())
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_STARTED, nextTask.events[1].GetEventType())

s.Equal(1, len(msgs))
s.Equal("test", msgs[0].GetProtocolInstanceId())
s.Equal(1, len(nextTask.msgs))
s.Equal("test", nextTask.msgs[0].GetProtocolInstanceId())

events, _, _, _, msgs, err = eh.NextCommandEvents()
nextTask, err = eh.nextTask()
s.NoError(err)
s.Equal(3, len(events))
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_COMPLETED, events[0].GetEventType())
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED, events[1].GetEventType())
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_STARTED, events[2].GetEventType())
s.Equal(3, len(nextTask.events))
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_COMPLETED, nextTask.events[0].GetEventType())
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED, nextTask.events[1].GetEventType())
s.Equal(enumspb.EVENT_TYPE_WORKFLOW_TASK_STARTED, nextTask.events[2].GetEventType())

s.Equal(0, len(msgs))
s.Equal(0, len(nextTask.msgs))
}

func (s *PollLayerInterfacesTestSuite) TestEmptyPages() {
Expand Down Expand Up @@ -421,16 +421,16 @@ func (s *PollLayerInterfacesTestSuite) TestEmptyPages() {
}

for _, expected := range expectedResults {
result, _, _, _, msgs, err := eh.NextCommandEvents()
nexTask, err := eh.nextTask()
s.NoError(err)
s.Equal(len(expected.events), len(result))
for i, event := range result {
s.Equal(len(expected.events), len(nexTask.events))
for i, event := range nexTask.events {
s.Equal(expected.events[i].EventId, event.EventId)
s.Equal(expected.events[i].EventType, event.EventType)
}

s.Equal(len(expected.messages), len(msgs))
for i, msg := range msgs {
s.Equal(len(expected.messages), len(nexTask.msgs))
for i, msg := range nexTask.msgs {
s.Equal(expected.messages[i].ProtocolInstanceId, msg.ProtocolInstanceId)
}
}
Expand Down

0 comments on commit b9e5e24

Please sign in to comment.