diff --git a/internal/internal_workflow_testsuite.go b/internal/internal_workflow_testsuite.go index 4f2f11614..3316861a3 100644 --- a/internal/internal_workflow_testsuite.go +++ b/internal/internal_workflow_testsuite.go @@ -2215,6 +2215,21 @@ func (env *testWorkflowEnvironmentImpl) queryWorkflow(queryType string, args ... return newEncodedValue(blob, env.GetDataConverter()), nil } +func (env *testWorkflowEnvironmentImpl) queryWorkflowByID(workflowID, queryType string, args ...interface{}) (converter.EncodedValue, error) { + if workflowHandle, ok := env.runningWorkflows[workflowID]; ok { + data, err := encodeArgs(workflowHandle.env.GetDataConverter(), args) + if err != nil { + return nil, err + } + blob, err := workflowHandle.env.queryHandler(queryType, data) + if err != nil { + return nil, err + } + return newEncodedValue(blob, workflowHandle.env.GetDataConverter()), nil + } + return nil, serviceerror.NewNotFound(fmt.Sprintf("Workflow %v not exists", workflowID)) +} + func (env *testWorkflowEnvironmentImpl) getMockRunFn(callWrapper *MockCallWrapper) func(args mock.Arguments) { env.locker.Lock() defer env.locker.Unlock() diff --git a/internal/internal_workflow_testsuite_test.go b/internal/internal_workflow_testsuite_test.go index 26763b352..f409eae88 100644 --- a/internal/internal_workflow_testsuite_test.go +++ b/internal/internal_workflow_testsuite_test.go @@ -1857,6 +1857,79 @@ func (s *WorkflowTestSuiteUnitTest) Test_QueryWorkflow() { verifyStateWithQuery(stateDone) } +func (s *WorkflowTestSuiteUnitTest) Test_QueryChildWorkflow() { + queryType := "state" + childWorkflowID := "test-query-child-workflow" + stateWaitSignal, stateWaitActivity, stateDone := "wait for signal", "wait for activity", "done" + childWorkflowFn := func(ctx Context) error { + var state string + err := SetQueryHandler(ctx, queryType, func(queryInput string) (string, error) { + return queryInput + state, nil + }) + if err != nil { + return err + } + state = stateWaitSignal + var signalData string + GetSignalChannel(ctx, "query-signal").Receive(ctx, &signalData) + + state = stateWaitActivity + ctx = WithActivityOptions(ctx, s.activityOptions) + err = ExecuteActivity(ctx, testActivityHello, signalData).Get(ctx, nil) + if err != nil { + return err + } + state = stateDone + return err + } + workflowFn := func(ctx Context) error { + cwo := ChildWorkflowOptions{ + WorkflowID: childWorkflowID, + WorkflowRunTimeout: time.Hour * 3, + RetryPolicy: &RetryPolicy{ + InitialInterval: time.Second * 3, + MaximumInterval: time.Second * 3, + BackoffCoefficient: 1, + }, + } + ctx = WithChildWorkflowOptions(ctx, cwo) + err := ExecuteChildWorkflow(ctx, childWorkflowFn).Get(ctx, nil) + if err != nil { + return err + } + return err + } + + env := s.NewTestWorkflowEnvironment() + env.RegisterWorkflow(childWorkflowFn) + env.RegisterWorkflow(workflowFn) + env.RegisterActivity(testActivityHello) + verifyStateWithQuery := func(expected string) { + encodedValue, err := env.QueryWorkflowByID(childWorkflowID, queryType, "input") + s.NoError(err) + s.NotNil(encodedValue) + var state string + err = encodedValue.Get(&state) + s.NoError(err) + s.Equal("input"+expected, state) + } + + env.RegisterDelayedCallback(func() { + verifyStateWithQuery(stateWaitSignal) + _ = env.SignalWorkflowByID(childWorkflowID, "query-signal", "hello-query") + }, time.Hour) + env.OnActivity(testActivityHello, mock.Anything, mock.Anything).After(time.Hour).Return("hello_mock", nil) + env.SetOnActivityStartedListener(func(activityInfo *ActivityInfo, ctx context.Context, args converter.EncodedValues) { + verifyStateWithQuery(stateWaitActivity) + }) + env.ExecuteWorkflow(workflowFn) + + s.True(env.IsWorkflowCompleted()) + s.NoError(env.GetWorkflowError()) + env.AssertExpectations(s.T()) + verifyStateWithQuery(stateDone) +} + func (s *WorkflowTestSuiteUnitTest) Test_WorkflowWithLocalActivity() { localActivityFn := func(ctx context.Context, name string) (string, error) { return "hello " + name, nil diff --git a/internal/workflow_testsuite.go b/internal/workflow_testsuite.go index 8eac31ba3..c1eeeb110 100644 --- a/internal/workflow_testsuite.go +++ b/internal/workflow_testsuite.go @@ -714,6 +714,11 @@ func (e *TestWorkflowEnvironment) QueryWorkflow(queryType string, args ...interf return e.impl.queryWorkflow(queryType, args...) } +// QueryWorkflowByID queries a child workflow by its ID and returns the result synchronously +func (e *TestWorkflowEnvironment) QueryWorkflowByID(workflowID, queryType string, args ...interface{}) (converter.EncodedValue, error) { + return e.impl.queryWorkflowByID(workflowID, queryType, args...) +} + // RegisterDelayedCallback creates a new timer with specified delayDuration using workflow clock (not wall clock). When // the timer fires, the callback will be called. By default, this test suite uses mock clock which automatically move // forward to fire next timer when workflow is blocked. Use this API to make some event (like activity completion,