From 2a582f80c9341cf2d0350ffc12b2fba6fe43ac73 Mon Sep 17 00:00:00 2001 From: Sebastian Neira Date: Thu, 10 Mar 2022 15:03:24 +0100 Subject: [PATCH] Add mock call assertions to TestWorkflowEnvironment (#748) --- internal/workflow_testsuite.go | 24 +++++++++++++ internal/workflow_testsuite_test.go | 55 +++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/internal/workflow_testsuite.go b/internal/workflow_testsuite.go index f21ce4145..010c75816 100644 --- a/internal/workflow_testsuite.go +++ b/internal/workflow_testsuite.go @@ -814,3 +814,27 @@ func (e *TestWorkflowEnvironment) SetSearchAttributesOnStart(searchAttributes ma func (e *TestWorkflowEnvironment) AssertExpectations(t mock.TestingT) bool { return e.mock.AssertExpectations(t) } + +// AssertCalled asserts that the method was called with the supplied arguments. +// Useful to assert that an Activity was called from within a workflow with the expected arguments. +// Since the first argument is a context, consider using mock.Anything for that argument. +// +// env.OnActivity(namedActivity, mock.Anything, mock.Anything).Return("mock_result", nil) +// env.ExecuteWorkflow(workflowThatCallsActivityWithItsArgument, "Hello") +// env.AssertCalled(t, "namedActivity", mock.Anything, "Hello") +// +// It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method. +func (e *TestWorkflowEnvironment) AssertCalled(t mock.TestingT, methodName string, arguments ...interface{}) bool { + return e.mock.AssertCalled(t, methodName, arguments...) +} + +// AssertNotCalled asserts that the method was not called with the given arguments. +// See AssertCalled for more info. +func (e *TestWorkflowEnvironment) AssertNotCalled(t mock.TestingT, methodName string, arguments ...interface{}) bool { + return e.mock.AssertNotCalled(t, methodName, arguments...) +} + +// AssertNumberOfCalls asserts that a method was called expectedCalls times. +func (e *TestWorkflowEnvironment) AssertNumberOfCalls(t mock.TestingT, methodName string, expectedCalls int) bool { + return e.mock.AssertNumberOfCalls(t, methodName, expectedCalls) +} diff --git a/internal/workflow_testsuite_test.go b/internal/workflow_testsuite_test.go index 7d58b7399..01ef4f633 100644 --- a/internal/workflow_testsuite_test.go +++ b/internal/workflow_testsuite_test.go @@ -31,6 +31,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -236,3 +237,57 @@ func TestWorkflowStartTimeInsideTestWorkflow(t *testing.T) { require.NoError(t, env.GetWorkflowResult(×tamp)) require.Equal(t, env.Now().Unix(), timestamp) } + +func TestActivityAssertCalled(t *testing.T) { + testSuite := &WorkflowTestSuite{} + env := testSuite.NewTestWorkflowEnvironment() + + env.RegisterActivity(namedActivity) + env.OnActivity(namedActivity, mock.Anything, mock.Anything).Return("Mock!", nil) + + env.ExecuteWorkflow(func(ctx Context, arg1 string) (string, error) { + ctx = WithLocalActivityOptions(ctx, LocalActivityOptions{ + ScheduleToCloseTimeout: time.Hour, + StartToCloseTimeout: time.Hour, + }) + var result string + err := ExecuteLocalActivity(ctx, "namedActivity", arg1).Get(ctx, &result) + if err != nil { + return "", err + } + return result, nil + }, "Hello") + + require.NoError(t, env.GetWorkflowError()) + var result string + err := env.GetWorkflowResult(&result) + require.NoError(t, err) + + require.Equal(t, "Mock!", result) + env.AssertCalled(t, "namedActivity", mock.Anything, "Hello") + env.AssertNotCalled(t, "namedActivity", mock.Anything, "Bye") +} + +func TestActivityAssertNumberOfCalls(t *testing.T) { + testSuite := &WorkflowTestSuite{} + env := testSuite.NewTestWorkflowEnvironment() + + env.RegisterActivity(namedActivity) + env.OnActivity(namedActivity, mock.Anything, mock.Anything).Return("Mock!", nil) + + env.ExecuteWorkflow(func(ctx Context, arg1 string) (string, error) { + ctx = WithLocalActivityOptions(ctx, LocalActivityOptions{ + ScheduleToCloseTimeout: time.Hour, + StartToCloseTimeout: time.Hour, + }) + var result string + _ = ExecuteLocalActivity(ctx, "namedActivity", arg1).Get(ctx, &result) + _ = ExecuteLocalActivity(ctx, "namedActivity", arg1).Get(ctx, &result) + _ = ExecuteLocalActivity(ctx, "namedActivity", arg1).Get(ctx, &result) + return result, nil + }, "Hello") + + require.NoError(t, env.GetWorkflowError()) + env.AssertNumberOfCalls(t, "namedActivity", 3) + env.AssertNumberOfCalls(t, "otherActivity", 0) +}