Skip to content

Commit

Permalink
Options validation removal (#102)
Browse files Browse the repository at this point in the history
Service already validates all the passed options. Validating them on the client leads to situations when workflows fail due to a bad option when they could be recovered. And even worse it leads to situations when validation logic on the client gets out of sync with the service.

Also fixed TestWorkflowEnvironment to support mock testing of activities by their string name without any registration. Also removed initialization of a session worker in TestWorkflowEnvironment unless requested.
  • Loading branch information
mfateev committed Apr 14, 2020
1 parent 6d109bf commit 9592ebc
Show file tree
Hide file tree
Showing 12 changed files with 283 additions and 235 deletions.
12 changes: 6 additions & 6 deletions internal/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -551,13 +551,13 @@ const (

// NewClient creates an instance of a workflow client
func NewClient(options ClientOptions) (Client, error) {
if len(options.Namespace) == 0 {
if options.Namespace == "" {
options.Namespace = DefaultNamespace
}

options.MetricsScope = tagScope(options.MetricsScope, tagNamespace, options.Namespace, clientImplHeaderName, clientImplHeaderValue)

if len(options.HostPort) == 0 {
if options.HostPort == "" {
options.HostPort = LocalHostPort
}

Expand All @@ -581,11 +581,11 @@ func NewClient(options ClientOptions) (Client, error) {
// NewServiceClient creates workflow client from workflowservice.WorkflowServiceClient. Must be used internally in unit tests only.
func NewServiceClient(workflowServiceClient workflowservice.WorkflowServiceClient, connectionCloser io.Closer, options ClientOptions) *WorkflowClient {
// Namespace can be empty in unit tests.
if len(options.Namespace) == 0 {
if options.Namespace == "" {
options.Namespace = DefaultNamespace
}

if len(options.Identity) == 0 {
if options.Identity == "" {
options.Identity = getWorkerIdentity("")
}

Expand Down Expand Up @@ -616,7 +616,7 @@ func NewServiceClient(workflowServiceClient workflowservice.WorkflowServiceClien
func NewNamespaceClient(options ClientOptions) (NamespaceClient, error) {
options.MetricsScope = tagScope(options.MetricsScope, clientImplHeaderName, clientImplHeaderValue)

if len(options.HostPort) == 0 {
if options.HostPort == "" {
options.HostPort = LocalHostPort
}

Expand All @@ -638,7 +638,7 @@ func NewNamespaceClient(options ClientOptions) (NamespaceClient, error) {
}

func newNamespaceServiceClient(workflowServiceClient workflowservice.WorkflowServiceClient, clientConn *grpc.ClientConn, options ClientOptions) NamespaceClient {
if len(options.Identity) == 0 {
if options.Identity == "" {
options.Identity = getWorkerIdentity("")
}

Expand Down
64 changes: 0 additions & 64 deletions internal/internal_activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,39 +170,6 @@ func getLocalActivityOptions(ctx Context) *localActivityOptions {
return opts.(*localActivityOptions)
}

func getValidatedActivityOptions(ctx Context) (*activityOptions, error) {
p := getActivityOptions(ctx)
if p == nil {
// We need task list as a compulsory parameter. This can be removed after registration
return nil, errActivityParamsBadRequest
}
if p.TaskListName == "" {
// We default to origin task list name.
p.TaskListName = p.OriginalTaskListName
}
if p.ScheduleToStartTimeoutSeconds <= 0 {
return nil, errors.New("missing or negative ScheduleToStartTimeoutSeconds")
}
if p.StartToCloseTimeoutSeconds <= 0 {
return nil, errors.New("missing or negative StartToCloseTimeoutSeconds")
}
if p.ScheduleToCloseTimeoutSeconds < 0 {
return nil, errors.New("missing or negative ScheduleToCloseTimeoutSeconds")
}
if p.ScheduleToCloseTimeoutSeconds == 0 {
// This is a optional parameter, we default to sum of the other two timeouts.
p.ScheduleToCloseTimeoutSeconds = p.ScheduleToStartTimeoutSeconds + p.StartToCloseTimeoutSeconds
}
if p.HeartbeatTimeoutSeconds < 0 {
return nil, errors.New("invalid negative HeartbeatTimeoutSeconds")
}
if err := validateRetryPolicy(p.RetryPolicy); err != nil {
return nil, err
}

return p, nil
}

func getValidatedLocalActivityOptions(ctx Context) (*localActivityOptions, error) {
p := getLocalActivityOptions(ctx)
if p == nil {
Expand All @@ -215,37 +182,6 @@ func getValidatedLocalActivityOptions(ctx Context) (*localActivityOptions, error
return p, nil
}

func validateRetryPolicy(p *commonpb.RetryPolicy) error {
if p == nil {
return nil
}

if p.GetInitialIntervalInSeconds() <= 0 {
return errors.New("missing or negative InitialIntervalInSeconds on retry policy")
}
if p.GetMaximumIntervalInSeconds() < 0 {
return errors.New("negative MaximumIntervalInSeconds on retry policy is invalid")
}
if p.GetMaximumIntervalInSeconds() == 0 {
// if not set, default to 100x of initial interval
p.MaximumIntervalInSeconds = 100 * p.GetInitialIntervalInSeconds()
}
if p.GetMaximumAttempts() < 0 {
return errors.New("negative MaximumAttempts on retry policy is invalid")
}
if p.GetExpirationIntervalInSeconds() < 0 {
return errors.New("ExpirationIntervalInSeconds cannot be less than 0 on retry policy")
}
if p.GetBackoffCoefficient() < 1 {
return errors.New("BackoffCoefficient on retry policy cannot be less than 1.0")
}
if p.GetMaximumAttempts() == 0 && p.GetExpirationIntervalInSeconds() == 0 {
return errors.New("both MaximumAttempts and ExpirationIntervalInSeconds on retry policy are not set, at least one of them must be set")
}

return nil
}

func validateFunctionArgs(f interface{}, args []interface{}, isWorkflow bool) error {
fType := reflect.TypeOf(f)
if fType == nil || fType.Kind() != reflect.Func {
Expand Down
10 changes: 4 additions & 6 deletions internal/internal_coroutines_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1256,17 +1256,15 @@ func TestChainedFuture(t *testing.T) {
activityFn := func(arg int) (int, error) {
return arg, nil
}
workflowFn := func(ctx Context) (int, error) {
workflowFn := func(ctx Context) (out int, err error) {
ctx = WithActivityOptions(ctx, ActivityOptions{
ScheduleToStartTimeout: time.Minute,
StartToCloseTimeout: time.Minute,
ScheduleToCloseTimeout: time.Minute,
})
f := ExecuteActivity(ctx, activityFn, 5)
var out int
fut, set := NewFuture(ctx)
set.Chain(f)
require.NoError(t, fut.Get(ctx, &out))
return out, nil
err = fut.Get(ctx, &out)
return
}

s := WorkflowTestSuite{}
Expand Down
4 changes: 2 additions & 2 deletions internal/internal_task_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,8 +598,8 @@ func (wth *workflowTaskHandlerImpl) createWorkflowContext(task *workflowservice.
return nil, errors.New("first history event is not WorkflowExecutionStarted")
}
taskList := attributes.TaskList
if taskList == nil {
return nil, errors.New("nil TaskList in WorkflowExecutionStarted event")
if taskList == nil || taskList.Name == "" {
return nil, errors.New("nil or empty TaskList in WorkflowExecutionStarted event")
}

runID := task.WorkflowExecution.GetRunId()
Expand Down
4 changes: 2 additions & 2 deletions internal/internal_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func verifyNamespaceExist(client workflowservice.WorkflowServiceClient, namespac
return nil
}

if len(namespace) == 0 {
if namespace == "" {
return errors.New("namespace cannot be empty")
}

Expand Down Expand Up @@ -1536,7 +1536,7 @@ func setClientDefaults(client *WorkflowClient) {
if client.dataConverter == nil {
client.dataConverter = getDefaultDataConverter()
}
if len(client.namespace) == 0 {
if client.namespace == "" {
client.namespace = DefaultNamespace
}
if client.tracer == nil {
Expand Down
44 changes: 0 additions & 44 deletions internal/internal_workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import (
"time"
"unicode"

"github.com/robfig/cron"
"go.uber.org/atomic"
"go.uber.org/zap"

Expand Down Expand Up @@ -1150,49 +1149,6 @@ func getValidatedWorkflowFunction(workflowFunc interface{}, args []interface{},
return &WorkflowType{Name: fnName}, input, nil
}

func getValidatedWorkflowOptions(ctx Context) (*workflowOptions, error) {
p := getWorkflowEnvOptions(ctx)
if p == nil {
// We need task list as a compulsory parameter. This can be removed after registration
return nil, errWorkflowOptionBadRequest
}
info := GetWorkflowInfo(ctx)
if p.namespace == "" {
// default to use current workflow's namespace
p.namespace = info.Namespace
}
if p.taskListName == "" {
// default to use current workflow's task list
p.taskListName = info.TaskListName
}
if p.taskStartToCloseTimeoutSeconds < 0 {
return nil, errors.New("missing or negative DecisionTaskStartToCloseTimeout")
}
if p.taskStartToCloseTimeoutSeconds == 0 {
p.taskStartToCloseTimeoutSeconds = defaultDecisionTaskTimeoutInSecs
}
if p.executionStartToCloseTimeoutSeconds <= 0 {
return nil, errors.New("missing or invalid ExecutionStartToCloseTimeout")
}
if err := validateRetryPolicy(p.retryPolicy); err != nil {
return nil, err
}
if err := validateCronSchedule(p.cronSchedule); err != nil {
return nil, err
}

return p, nil
}

func validateCronSchedule(cronSchedule string) error {
if len(cronSchedule) == 0 {
return nil
}

_, err := cron.ParseStandard(cronSchedule)
return err
}

func getWorkflowEnvOptions(ctx Context) *workflowOptions {
options := ctx.Value(workflowEnvOptionsContextKey)
if options != nil {
Expand Down
43 changes: 7 additions & 36 deletions internal/internal_workflow_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ var _ Client = (*WorkflowClient)(nil)
var _ NamespaceClient = (*namespaceClient)(nil)

const (
defaultDecisionTaskTimeoutInSecs = 10
defaultGetHistoryTimeoutInSecs = 25
defaultGetHistoryTimeoutInSecs = 65
)

var (
Expand Down Expand Up @@ -167,22 +166,8 @@ func (wc *WorkflowClient) StartWorkflow(
workflowID = uuid.NewRandom().String()
}

if options.TaskList == "" {
return nil, errors.New("missing TaskList")
}

executionTimeout := common.Int32Ceil(options.ExecutionStartToCloseTimeout.Seconds())
if executionTimeout <= 0 {
return nil, errors.New("missing or invalid ExecutionStartToCloseTimeout")
}

decisionTaskTimeout := common.Int32Ceil(options.DecisionTaskStartToCloseTimeout.Seconds())
if decisionTaskTimeout < 0 {
return nil, errors.New("negative DecisionTaskStartToCloseTimeout provided")
}
if decisionTaskTimeout == 0 {
decisionTaskTimeout = defaultDecisionTaskTimeoutInSecs
}

// Validate type and its arguments.
workflowType, input, err := getValidatedWorkflowFunction(workflowFunc, args, wc.dataConverter, wc.registry)
Expand Down Expand Up @@ -364,22 +349,8 @@ func (wc *WorkflowClient) SignalWithStartWorkflow(ctx context.Context, workflowI
workflowID = uuid.NewRandom().String()
}

if options.TaskList == "" {
return nil, errors.New("missing TaskList")
}

executionTimeout := common.Int32Ceil(options.ExecutionStartToCloseTimeout.Seconds())
if executionTimeout <= 0 {
return nil, errors.New("missing or invalid ExecutionStartToCloseTimeout")
}

decisionTaskTimeout := common.Int32Ceil(options.DecisionTaskStartToCloseTimeout.Seconds())
if decisionTaskTimeout < 0 {
return nil, errors.New("negative DecisionTaskStartToCloseTimeout provided")
}
if decisionTaskTimeout == 0 {
decisionTaskTimeout = defaultDecisionTaskTimeoutInSecs
}

// Validate type and its arguments.
workflowType, input, err := getValidatedWorkflowFunction(workflowFunc, workflowArgs, wc.dataConverter, wc.registry)
Expand Down Expand Up @@ -620,7 +591,7 @@ func (wc *WorkflowClient) RecordActivityHeartbeatByID(ctx context.Context,
// - InternalServiceError
// - EntityNotExistError
func (wc *WorkflowClient) ListClosedWorkflow(ctx context.Context, request *workflowservice.ListClosedWorkflowExecutionsRequest) (*workflowservice.ListClosedWorkflowExecutionsResponse, error) {
if len(request.GetNamespace()) == 0 {
if request.GetNamespace() == "" {
request.Namespace = wc.namespace
}
var response *workflowservice.ListClosedWorkflowExecutionsResponse
Expand All @@ -644,7 +615,7 @@ func (wc *WorkflowClient) ListClosedWorkflow(ctx context.Context, request *workf
// - InternalServiceError
// - EntityNotExistError
func (wc *WorkflowClient) ListOpenWorkflow(ctx context.Context, request *workflowservice.ListOpenWorkflowExecutionsRequest) (*workflowservice.ListOpenWorkflowExecutionsResponse, error) {
if len(request.GetNamespace()) == 0 {
if request.GetNamespace() == "" {
request.Namespace = wc.namespace
}
var response *workflowservice.ListOpenWorkflowExecutionsResponse
Expand All @@ -664,7 +635,7 @@ func (wc *WorkflowClient) ListOpenWorkflow(ctx context.Context, request *workflo

// ListWorkflow implementation
func (wc *WorkflowClient) ListWorkflow(ctx context.Context, request *workflowservice.ListWorkflowExecutionsRequest) (*workflowservice.ListWorkflowExecutionsResponse, error) {
if len(request.GetNamespace()) == 0 {
if request.GetNamespace() == "" {
request.Namespace = wc.namespace
}
var response *workflowservice.ListWorkflowExecutionsResponse
Expand All @@ -684,7 +655,7 @@ func (wc *WorkflowClient) ListWorkflow(ctx context.Context, request *workflowser

// ListArchivedWorkflow implementation
func (wc *WorkflowClient) ListArchivedWorkflow(ctx context.Context, request *workflowservice.ListArchivedWorkflowExecutionsRequest) (*workflowservice.ListArchivedWorkflowExecutionsResponse, error) {
if len(request.GetNamespace()) == 0 {
if request.GetNamespace() == "" {
request.Namespace = wc.namespace
}
var response *workflowservice.ListArchivedWorkflowExecutionsResponse
Expand Down Expand Up @@ -716,7 +687,7 @@ func (wc *WorkflowClient) ListArchivedWorkflow(ctx context.Context, request *wor

// ScanWorkflow implementation
func (wc *WorkflowClient) ScanWorkflow(ctx context.Context, request *workflowservice.ScanWorkflowExecutionsRequest) (*workflowservice.ScanWorkflowExecutionsResponse, error) {
if len(request.GetNamespace()) == 0 {
if request.GetNamespace() == "" {
request.Namespace = wc.namespace
}
var response *workflowservice.ScanWorkflowExecutionsResponse
Expand All @@ -736,7 +707,7 @@ func (wc *WorkflowClient) ScanWorkflow(ctx context.Context, request *workflowser

// CountWorkflow implementation
func (wc *WorkflowClient) CountWorkflow(ctx context.Context, request *workflowservice.CountWorkflowExecutionsRequest) (*workflowservice.CountWorkflowExecutionsResponse, error) {
if len(request.GetNamespace()) == 0 {
if request.GetNamespace() == "" {
request.Namespace = wc.namespace
}
var response *workflowservice.CountWorkflowExecutionsResponse
Expand Down
30 changes: 1 addition & 29 deletions internal/internal_workflow_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ package internal
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"os"
Expand Down Expand Up @@ -232,7 +231,7 @@ func (s *historyEventIteratorSuite) TestIterator_NoError_EmptyPage() {
s.Equal(1, len(events))
}

func (s *historyEventIteratorSuite) TestIterator_Error() {
func (s *historyEventIteratorSuite) TestIteratorError() {
filterType := filterpb.HistoryEventFilterType_AllEvent
request1 := getGetWorkflowExecutionHistoryRequest(filterType)
response1 := &workflowservice.GetWorkflowExecutionHistoryResponse{
Expand Down Expand Up @@ -804,33 +803,6 @@ func (s *workflowClientTestSuite) TestSignalWithStartWorkflow() {
s.Equal(createResponse.GetRunId(), resp.RunID)
}

func (s *workflowClientTestSuite) TestSignalWithStartWorkflow_Error() {
signalName := "my signal"
signalInput := []byte("my signal input")
options := StartWorkflowOptions{}

resp, err := s.client.SignalWithStartWorkflow(context.Background(), workflowID, signalName, signalInput,
options, workflowType)
s.Equal(errors.New("missing TaskList"), err)
s.Nil(resp)

options.TaskList = tasklist
resp, err = s.client.SignalWithStartWorkflow(context.Background(), workflowID, signalName, signalInput,
options, workflowType)
s.NotNil(err)
s.Nil(resp)

options.ExecutionStartToCloseTimeout = timeoutInSeconds
createResponse := &workflowservice.SignalWithStartWorkflowExecutionResponse{
RunId: runID,
}
s.service.EXPECT().SignalWithStartWorkflowExecution(gomock.Any(), gomock.Any(), gomock.Any()).Return(createResponse, nil)
resp, err = s.client.SignalWithStartWorkflow(context.Background(), workflowID, signalName, signalInput,
options, workflowType)
s.Nil(err)
s.Equal(createResponse.GetRunId(), resp.RunID)
}

func (s *workflowClientTestSuite) TestStartWorkflow() {
client, ok := s.client.(*WorkflowClient)
s.True(ok)
Expand Down
Loading

0 comments on commit 9592ebc

Please sign in to comment.