Skip to content

Commit

Permalink
Add mock call assertions to TestWorkflowEnvironment (#748)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastian Neira committed Mar 10, 2022
1 parent 8608a59 commit 2a582f8
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
24 changes: 24 additions & 0 deletions internal/workflow_testsuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -814,3 +814,27 @@ func (e *TestWorkflowEnvironment) SetSearchAttributesOnStart(searchAttributes ma
func (e *TestWorkflowEnvironment) AssertExpectations(t mock.TestingT) bool {
return e.mock.AssertExpectations(t)
}

// AssertCalled asserts that the method was called with the supplied arguments.
// Useful to assert that an Activity was called from within a workflow with the expected arguments.
// Since the first argument is a context, consider using mock.Anything for that argument.
//
// env.OnActivity(namedActivity, mock.Anything, mock.Anything).Return("mock_result", nil)
// env.ExecuteWorkflow(workflowThatCallsActivityWithItsArgument, "Hello")
// env.AssertCalled(t, "namedActivity", mock.Anything, "Hello")
//
// It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method.
func (e *TestWorkflowEnvironment) AssertCalled(t mock.TestingT, methodName string, arguments ...interface{}) bool {
return e.mock.AssertCalled(t, methodName, arguments...)
}

// AssertNotCalled asserts that the method was not called with the given arguments.
// See AssertCalled for more info.
func (e *TestWorkflowEnvironment) AssertNotCalled(t mock.TestingT, methodName string, arguments ...interface{}) bool {
return e.mock.AssertNotCalled(t, methodName, arguments...)
}

// AssertNumberOfCalls asserts that a method was called expectedCalls times.
func (e *TestWorkflowEnvironment) AssertNumberOfCalls(t mock.TestingT, methodName string, expectedCalls int) bool {
return e.mock.AssertNumberOfCalls(t, methodName, expectedCalls)
}
55 changes: 55 additions & 0 deletions internal/workflow_testsuite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"testing"
"time"

"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -236,3 +237,57 @@ func TestWorkflowStartTimeInsideTestWorkflow(t *testing.T) {
require.NoError(t, env.GetWorkflowResult(&timestamp))
require.Equal(t, env.Now().Unix(), timestamp)
}

func TestActivityAssertCalled(t *testing.T) {
testSuite := &WorkflowTestSuite{}
env := testSuite.NewTestWorkflowEnvironment()

env.RegisterActivity(namedActivity)
env.OnActivity(namedActivity, mock.Anything, mock.Anything).Return("Mock!", nil)

env.ExecuteWorkflow(func(ctx Context, arg1 string) (string, error) {
ctx = WithLocalActivityOptions(ctx, LocalActivityOptions{
ScheduleToCloseTimeout: time.Hour,
StartToCloseTimeout: time.Hour,
})
var result string
err := ExecuteLocalActivity(ctx, "namedActivity", arg1).Get(ctx, &result)
if err != nil {
return "", err
}
return result, nil
}, "Hello")

require.NoError(t, env.GetWorkflowError())
var result string
err := env.GetWorkflowResult(&result)
require.NoError(t, err)

require.Equal(t, "Mock!", result)
env.AssertCalled(t, "namedActivity", mock.Anything, "Hello")
env.AssertNotCalled(t, "namedActivity", mock.Anything, "Bye")
}

func TestActivityAssertNumberOfCalls(t *testing.T) {
testSuite := &WorkflowTestSuite{}
env := testSuite.NewTestWorkflowEnvironment()

env.RegisterActivity(namedActivity)
env.OnActivity(namedActivity, mock.Anything, mock.Anything).Return("Mock!", nil)

env.ExecuteWorkflow(func(ctx Context, arg1 string) (string, error) {
ctx = WithLocalActivityOptions(ctx, LocalActivityOptions{
ScheduleToCloseTimeout: time.Hour,
StartToCloseTimeout: time.Hour,
})
var result string
_ = ExecuteLocalActivity(ctx, "namedActivity", arg1).Get(ctx, &result)
_ = ExecuteLocalActivity(ctx, "namedActivity", arg1).Get(ctx, &result)
_ = ExecuteLocalActivity(ctx, "namedActivity", arg1).Get(ctx, &result)
return result, nil
}, "Hello")

require.NoError(t, env.GetWorkflowError())
env.AssertNumberOfCalls(t, "namedActivity", 3)
env.AssertNumberOfCalls(t, "otherActivity", 0)
}

0 comments on commit 2a582f8

Please sign in to comment.