diff --git a/common/persistence/cassandra/errors.go b/common/persistence/cassandra/errors.go index fc452bf46b7..592fdfdce21 100644 --- a/common/persistence/cassandra/errors.go +++ b/common/persistence/cassandra/errors.go @@ -203,7 +203,7 @@ func extractCurrentWorkflowConflictError( binary, _ := conflictRecord["execution_state"].([]byte) encoding, _ := conflictRecord["execution_state_encoding"].(string) executionState := &persistencespb.WorkflowExecutionState{} - if state, err := serialization.WorkflowExecutionStateFromBlob(p.NewDataBlob(binary, encoding)); err == nil { + if state, err := serialization.DefaultDecoder.WorkflowExecutionStateFromBlob(p.NewDataBlob(binary, encoding)); err == nil { executionState = state } // if err != nil, this means execution state cannot be parsed, just use default values diff --git a/common/persistence/cassandra/errors_test.go b/common/persistence/cassandra/errors_test.go index 2977f56a3b0..aa19f30036f 100644 --- a/common/persistence/cassandra/errors_test.go +++ b/common/persistence/cassandra/errors_test.go @@ -228,7 +228,7 @@ func (s *cassandraErrorsSuite) TestExtractCurrentWorkflowConflictError_Success() }, }, } - blob, err := serialization.WorkflowExecutionStateToBlob(workflowState) + blob, err := serialization.NewSerializer().WorkflowExecutionStateToBlob(workflowState) lastWriteVersion := rand.Int63() s.NoError(err) t := rowTypeExecution diff --git a/common/persistence/cassandra/execution_store.go b/common/persistence/cassandra/execution_store.go index dce70942eb8..ae43bb49b54 100644 --- a/common/persistence/cassandra/execution_store.go +++ b/common/persistence/cassandra/execution_store.go @@ -6,6 +6,7 @@ import ( p "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/nosql/nosqlplugin/cassandra/gocql" + "go.temporal.io/server/common/persistence/serialization" ) // Guidelines for creating new special UUID constants @@ -84,11 +85,11 @@ type ( var _ p.ExecutionStore = (*ExecutionStore)(nil) -func 
NewExecutionStore(session gocql.Session) *ExecutionStore { +func NewExecutionStore(session gocql.Session, serializer serialization.Serializer) *ExecutionStore { return &ExecutionStore{ - HistoryStore: NewHistoryStore(session), - MutableStateStore: NewMutableStateStore(session), - MutableStateTaskStore: NewMutableStateTaskStore(session), + HistoryStore: NewHistoryStore(session, serializer), + MutableStateStore: NewMutableStateStore(session, serializer), + MutableStateTaskStore: NewMutableStateTaskStore(session, serializer), } } diff --git a/common/persistence/cassandra/factory.go b/common/persistence/cassandra/factory.go index 75bef7ff2db..76b21916200 100644 --- a/common/persistence/cassandra/factory.go +++ b/common/persistence/cassandra/factory.go @@ -10,6 +10,7 @@ import ( "go.temporal.io/server/common/metrics" p "go.temporal.io/server/common/persistence" commongocql "go.temporal.io/server/common/persistence/nosql/nosqlplugin/cassandra/gocql" + "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/resolver" ) @@ -21,6 +22,7 @@ type ( clusterName string logger log.Logger session commongocql.Session + serializer serialization.Serializer } ) @@ -32,6 +34,7 @@ func NewFactory( clusterName string, logger log.Logger, metricsHandler metrics.Handler, + serializer serialization.Serializer, ) *Factory { session, err := commongocql.NewSession( func() (*gocql.ClusterConfig, error) { @@ -43,7 +46,13 @@ func NewFactory( if err != nil { logger.Fatal("unable to initialize cassandra session", tag.Error(err)) } - return NewFactoryFromSession(cfg, clusterName, logger, session) + return NewFactoryFromSession( + cfg, + clusterName, + logger, + session, + serializer, + ) } // NewFactoryFromSession returns an instance of a factory object from the given session. 
@@ -52,12 +61,14 @@ func NewFactoryFromSession( clusterName string, logger log.Logger, session commongocql.Session, + serializer serialization.Serializer, ) *Factory { return &Factory{ cfg: cfg, clusterName: clusterName, logger: logger, session: session, + serializer: serializer, } } @@ -88,7 +99,7 @@ func (f *Factory) NewClusterMetadataStore() (p.ClusterMetadataStore, error) { // NewExecutionStore returns a new ExecutionStore. func (f *Factory) NewExecutionStore() (p.ExecutionStore, error) { - return NewExecutionStore(f.session), nil + return NewExecutionStore(f.session, f.serializer), nil } // NewQueue returns a new queue backed by cassandra diff --git a/common/persistence/cassandra/history_store.go b/common/persistence/cassandra/history_store.go index 14d57e4a758..e2a5abfd231 100644 --- a/common/persistence/cassandra/history_store.go +++ b/common/persistence/cassandra/history_store.go @@ -7,6 +7,7 @@ import ( "go.temporal.io/api/serviceerror" p "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/nosql/nosqlplugin/cassandra/gocql" + "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/primitives" ) @@ -44,13 +45,17 @@ const ( type ( HistoryStore struct { Session gocql.Session - p.HistoryBranchUtilImpl + p.HistoryBranchUtil } ) -func NewHistoryStore(session gocql.Session) *HistoryStore { +func NewHistoryStore( + session gocql.Session, + serializer serialization.Serializer, +) *HistoryStore { return &HistoryStore{ - Session: session, + Session: session, + HistoryBranchUtil: p.NewHistoryBranchUtil(serializer), } } @@ -137,7 +142,7 @@ func (h *HistoryStore) ReadHistoryBranch( ctx context.Context, request *p.InternalReadHistoryBranchRequest, ) (*p.InternalReadHistoryBranchResponse, error) { - branch, err := h.GetHistoryBranchUtil().ParseHistoryBranchInfo(request.BranchToken) + branch, err := h.ParseHistoryBranchInfo(request.BranchToken) if err != nil { return nil, err } @@ -346,7 +351,7 @@ func (h 
*HistoryStore) GetHistoryTreeContainingBranch( request *p.InternalGetHistoryTreeContainingBranchRequest, ) (*p.InternalGetHistoryTreeContainingBranchResponse, error) { - branch, err := h.GetHistoryBranchUtil().ParseHistoryBranchInfo(request.BranchToken) + branch, err := h.ParseHistoryBranchInfo(request.BranchToken) if err != nil { return nil, err } @@ -391,6 +396,10 @@ func (h *HistoryStore) GetHistoryTreeContainingBranch( }, nil } +func (h *HistoryStore) GetHistoryBranchUtil() p.HistoryBranchUtil { + return h.HistoryBranchUtil +} + func convertHistoryNode( message map[string]interface{}, ) p.InternalHistoryNode { diff --git a/common/persistence/cassandra/mutable_state_store.go b/common/persistence/cassandra/mutable_state_store.go index f53bf9c9de9..30d7387ef97 100644 --- a/common/persistence/cassandra/mutable_state_store.go +++ b/common/persistence/cassandra/mutable_state_store.go @@ -363,13 +363,15 @@ const ( type ( MutableStateStore struct { - Session gocql.Session + Session gocql.Session + serializer serialization.Serializer } ) -func NewMutableStateStore(session gocql.Session) *MutableStateStore { +func NewMutableStateStore(session gocql.Session, serializer serialization.Serializer) *MutableStateStore { return &MutableStateStore{ - Session: session, + Session: session, + serializer: serializer, } } @@ -648,7 +650,7 @@ func (d *MutableStateStore) UpdateWorkflowExecution( lastWriteVersion := updateWorkflow.LastWriteVersion // TODO: double encoding execution state? already in updateWorkflow.ExecutionStateBlob - executionStateDatablob, err := serialization.WorkflowExecutionStateToBlob(updateWorkflow.ExecutionState) + executionStateDatablob, err := d.serializer.WorkflowExecutionStateToBlob(updateWorkflow.ExecutionState) if err != nil { return err } @@ -962,7 +964,7 @@ func (d *MutableStateStore) GetCurrentExecution( } // TODO: fix blob ExecutionState in storage should not be a blob. 
- executionState, err := serialization.WorkflowExecutionStateFromBlob(executionStateBlob) + executionState, err := d.serializer.WorkflowExecutionStateFromBlob(executionStateBlob) if err != nil { return nil, err } diff --git a/common/persistence/cassandra/mutable_state_task_store.go b/common/persistence/cassandra/mutable_state_task_store.go index de6e7a665a0..05fed85160e 100644 --- a/common/persistence/cassandra/mutable_state_task_store.go +++ b/common/persistence/cassandra/mutable_state_task_store.go @@ -162,13 +162,15 @@ const ( type ( MutableStateTaskStore struct { - Session gocql.Session + Session gocql.Session + serializer serialization.Serializer } ) -func NewMutableStateTaskStore(session gocql.Session) *MutableStateTaskStore { +func NewMutableStateTaskStore(session gocql.Session, serializer serialization.Serializer) *MutableStateTaskStore { return &MutableStateTaskStore{ - Session: session, + Session: session, + serializer: serializer, } } @@ -503,7 +505,7 @@ func (d *MutableStateTaskStore) PutReplicationTaskToDLQ( request *p.PutReplicationTaskToDLQRequest, ) error { task := request.TaskInfo - datablob, err := serialization.ReplicationTaskInfoToBlob(task) + datablob, err := d.serializer.ReplicationTaskInfoToBlob(task) if err != nil { return gocql.ConvertError("PutReplicationTaskToDLQ", err) } diff --git a/common/persistence/cassandra/queue_store.go b/common/persistence/cassandra/queue_store.go index 7ab79d38332..0f6ba96d64d 100644 --- a/common/persistence/cassandra/queue_store.go +++ b/common/persistence/cassandra/queue_store.go @@ -30,9 +30,10 @@ const ( type ( QueueStore struct { - queueType persistence.QueueType - session gocql.Session - logger log.Logger + queueType persistence.QueueType + session gocql.Session + logger log.Logger + serializer serialization.Serializer } ) @@ -42,9 +43,10 @@ func NewQueueStore( logger log.Logger, ) (persistence.Queue, error) { return &QueueStore{ - queueType: queueType, - session: session, - logger: logger, + queueType: 
queueType, + session: session, + logger: logger, + serializer: serialization.NewSerializer(), }, nil } @@ -300,7 +302,7 @@ func (q *QueueStore) getQueueMetadata( return nil, err } - return convertQueueMetadata(message) + return convertQueueMetadata(message, q.serializer) } func (q *QueueStore) updateAckLevel( @@ -310,7 +312,7 @@ func (q *QueueStore) updateAckLevel( ) error { // TODO: remove this once cluster_ack_level is removed from DB - metadataStruct, err := serialization.QueueMetadataFromBlob(metadata.Blob) + metadataStruct, err := q.serializer.QueueMetadataFromBlob(metadata.Blob) if err != nil { return gocql.ConvertError("updateAckLevel", err) } @@ -385,6 +387,7 @@ func convertQueueMessage( func convertQueueMetadata( message map[string]interface{}, + serializer serialization.Serializer, ) (*persistence.InternalQueueMetadata, error) { metadata := &persistence.InternalQueueMetadata{ @@ -394,7 +397,7 @@ func convertQueueMetadata( if ok { clusterAckLevel := message["cluster_ack_level"].(map[string]int64) // TODO: remove this once we remove cluster_ack_level from DB. 
- blob, err := serialization.QueueMetadataToBlob(&persistencespb.QueueMetadata{ClusterAckLevels: clusterAckLevel}) + blob, err := serializer.QueueMetadataToBlob(&persistencespb.QueueMetadata{ClusterAckLevels: clusterAckLevel}) if err != nil { return nil, err } diff --git a/common/persistence/client/factory.go b/common/persistence/client/factory.go index f1e58fb82aa..8ff531f845e 100644 --- a/common/persistence/client/factory.go +++ b/common/persistence/client/factory.go @@ -193,7 +193,7 @@ func (f *factoryImpl) NewExecutionManager() (persistence.ExecutionManager, error return nil, err } - result := persistence.NewExecutionManager(store, f.serializer, f.eventBlobCache, f.logger, f.config.TransactionSizeLimit) + result := persistence.NewExecutionManager(store, f.serializer, serialization.NewTaskSerializer(f.serializer), f.eventBlobCache, f.logger, f.config.TransactionSizeLimit) if f.systemRateLimiter != nil && f.namespaceRateLimiter != nil { result = persistence.NewExecutionPersistenceRateLimitedClient(result, f.systemRateLimiter, f.namespaceRateLimiter, f.shardRateLimiter, f.logger) } @@ -225,7 +225,7 @@ func (f *factoryImpl) NewHistoryTaskQueueManager() (persistence.HistoryTaskQueue if err != nil { return nil, err } - return persistence.NewHistoryTaskQueueManager(q, serialization.NewSerializer()), nil + return persistence.NewHistoryTaskQueueManager(q, f.serializer), nil } func (f *factoryImpl) NewNexusEndpointManager() (persistence.NexusEndpointManager, error) { diff --git a/common/persistence/client/fx.go b/common/persistence/client/fx.go index 7723119f9b6..5069bd7be6b 100644 --- a/common/persistence/client/fx.go +++ b/common/persistence/client/fx.go @@ -52,6 +52,7 @@ type ( Logger log.Logger HealthSignals persistence.HealthSignalAggregator DynamicRateLimitingParams DynamicRateLimitingParams + Serializer serialization.Serializer } FactoryProviderFn func(NewFactoryParams) Factory @@ -83,6 +84,7 @@ func ClusterNameProvider(config *cluster.Config) ClusterName { func 
EventBlobCacheProvider( dc *dynamicconfig.Collection, logger log.Logger, + serializer serialization.Serializer, ) persistence.XDCCache { return persistence.NewEventsBlobCache( dynamicconfig.XDCCacheMaxSizeBytes.Get(dc)(), @@ -128,7 +130,7 @@ func FactoryProvider( systemRequestRateLimiter, namespaceRequestRateLimiter, shardRequestRateLimiter, - serialization.NewSerializer(), + params.Serializer, params.EventBlobCache, string(params.ClusterName), params.MetricsHandler, @@ -163,14 +165,15 @@ func DataStoreFactoryProvider( logger log.Logger, metricsHandler metrics.Handler, tracerProvider trace.TracerProvider, + serializer serialization.Serializer, ) persistence.DataStoreFactory { var dataStoreFactory persistence.DataStoreFactory defaultStoreCfg := cfg.DataStores[cfg.DefaultStore] switch { case defaultStoreCfg.Cassandra != nil: - dataStoreFactory = cassandra.NewFactory(*defaultStoreCfg.Cassandra, r, string(clusterName), logger, metricsHandler) + dataStoreFactory = cassandra.NewFactory(*defaultStoreCfg.Cassandra, r, string(clusterName), logger, metricsHandler, serializer) case defaultStoreCfg.SQL != nil: - dataStoreFactory = sql.NewFactory(*defaultStoreCfg.SQL, r, string(clusterName), logger, metricsHandler) + dataStoreFactory = sql.NewFactory(*defaultStoreCfg.SQL, r, string(clusterName), serializer, logger, metricsHandler) case defaultStoreCfg.CustomDataStoreConfig != nil: dataStoreFactory = abstractDataStoreFactory.NewFactory(*defaultStoreCfg.CustomDataStoreConfig, r, string(clusterName), logger, metricsHandler) default: diff --git a/common/persistence/data_interfaces.go b/common/persistence/data_interfaces.go index e8fe14d32d2..284bee7d614 100644 --- a/common/persistence/data_interfaces.go +++ b/common/persistence/data_interfaces.go @@ -1231,8 +1231,9 @@ type ( } HistoryTaskQueueManagerImpl struct { - queue QueueV2 - serializer serialization.Serializer + queue QueueV2 + serializer serialization.Serializer + taskSerializer serialization.TaskSerializer } // QueueKey 
identifies a history task queue. It is converted to a queue name using the GetQueueName method. diff --git a/common/persistence/execution_manager.go b/common/persistence/execution_manager.go index 391ac30f106..91a2a1c01d8 100644 --- a/common/persistence/execution_manager.go +++ b/common/persistence/execution_manager.go @@ -26,6 +26,7 @@ type ( // executionManagerImpl implements ExecutionManager based on ExecutionStore, statsComputer and Serializer executionManagerImpl struct { serializer serialization.Serializer + taskSerializer serialization.TaskSerializer eventBlobCache XDCCache persistence ExecutionStore logger log.Logger @@ -40,12 +41,14 @@ var _ ExecutionManager = (*executionManagerImpl)(nil) func NewExecutionManager( persistence ExecutionStore, serializer serialization.Serializer, + taskSerializer serialization.TaskSerializer, eventBlobCache XDCCache, logger log.Logger, transactionSizeLimit dynamicconfig.IntPropertyFn, ) ExecutionManager { return &executionManagerImpl{ serializer: serializer, + taskSerializer: taskSerializer, eventBlobCache: eventBlobCache, persistence: persistence, logger: logger, @@ -528,7 +531,7 @@ func (m *executionManagerImpl) SerializeWorkflowMutation( // unexport input *WorkflowMutation, ) (*InternalWorkflowMutation, error) { - tasks, err := serializeTasks(m.serializer, input.Tasks) + serializedTasks, err := serializeTasks(m.taskSerializer, input.Tasks) if err != nil { return nil, err } @@ -565,7 +568,7 @@ func (m *executionManagerImpl) SerializeWorkflowMutation( // unexport ExecutionInfo: input.ExecutionInfo, ExecutionState: input.ExecutionState, - Tasks: tasks, + Tasks: serializedTasks, Condition: input.Condition, DBRecordVersion: input.DBRecordVersion, @@ -649,8 +652,7 @@ func (m *executionManagerImpl) SerializeWorkflowMutation( // unexport func (m *executionManagerImpl) SerializeWorkflowSnapshot( // unexport input *WorkflowSnapshot, ) (*InternalWorkflowSnapshot, error) { - - tasks, err := serializeTasks(m.serializer, input.Tasks) + 
serializedTasks, err := serializeTasks(m.taskSerializer, input.Tasks) if err != nil { return nil, err } @@ -671,7 +673,7 @@ func (m *executionManagerImpl) SerializeWorkflowSnapshot( // unexport ExecutionState: input.ExecutionState, SignalRequestedIDs: make(map[string]struct{}), - Tasks: tasks, + Tasks: serializedTasks, Condition: input.Condition, DBRecordVersion: input.DBRecordVersion, @@ -807,7 +809,7 @@ func (m *executionManagerImpl) AddHistoryTasks( ctx context.Context, input *AddHistoryTasksRequest, ) error { - tasks, err := serializeTasks(m.serializer, input.Tasks) + serializedTasks, err := serializeTasks(m.taskSerializer, input.Tasks) if err != nil { return err } @@ -819,7 +821,7 @@ func (m *executionManagerImpl) AddHistoryTasks( NamespaceID: input.NamespaceID, WorkflowID: input.WorkflowID, - Tasks: tasks, + Tasks: serializedTasks, }) } @@ -842,7 +844,7 @@ func (m *executionManagerImpl) GetHistoryTasks( historyTasks := make([]tasks.Task, 0, len(resp.Tasks)) for _, internalTask := range resp.Tasks { - task, err := m.serializer.DeserializeTask(request.TaskCategory, internalTask.Blob) + task, err := m.taskSerializer.DeserializeTask(request.TaskCategory, internalTask.Blob) if err != nil { return nil, err } @@ -903,7 +905,7 @@ func (m *executionManagerImpl) GetReplicationTasksFromDLQ( dlqTasks := make([]tasks.Task, 0, len(resp.Tasks)) for i := range resp.Tasks { internalTask := resp.Tasks[i] - task, err := m.serializer.DeserializeTask(category, internalTask.Blob) + task, err := m.taskSerializer.DeserializeTask(category, internalTask.Blob) if err != nil { return nil, err } @@ -1139,14 +1141,14 @@ func getCurrentBranchLastWriteVersion( } func serializeTasks( - serializer serialization.Serializer, + taskSerializer serialization.TaskSerializer, inputTasks map[tasks.Category][]tasks.Task, ) (map[tasks.Category][]InternalHistoryTask, error) { outputTasks := make(map[tasks.Category][]InternalHistoryTask) for category, tasks := range inputTasks { serializedTasks := 
make([]InternalHistoryTask, 0, len(tasks)) for _, task := range tasks { - blob, err := serializer.SerializeTask(task) + blob, err := taskSerializer.SerializeTask(task) if err != nil { return nil, err } diff --git a/common/persistence/history_branch_util.go b/common/persistence/history_branch_util.go index 48814988067..1b8dd589166 100644 --- a/common/persistence/history_branch_util.go +++ b/common/persistence/history_branch_util.go @@ -5,8 +5,6 @@ package persistence import ( "time" - commonpb "go.temporal.io/api/common/v1" - enumspb "go.temporal.io/api/enums/v1" persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/primitives" @@ -38,9 +36,16 @@ type ( } HistoryBranchUtilImpl struct { + serializer serialization.Serializer } ) +func NewHistoryBranchUtil(serializer serialization.Serializer) *HistoryBranchUtilImpl { + return &HistoryBranchUtilImpl{ + serializer: serializer, + } +} + func (u *HistoryBranchUtilImpl) NewHistoryBranch( _ string, // namespaceID _ string, // workflowID @@ -63,7 +68,7 @@ func (u *HistoryBranchUtilImpl) NewHistoryBranch( BranchId: id, Ancestors: ancestors, } - data, err := serialization.HistoryBranchToBlob(bi) + data, err := u.serializer.HistoryBranchToBlob(bi) if err != nil { return nil, err } @@ -73,7 +78,7 @@ func (u *HistoryBranchUtilImpl) NewHistoryBranch( func (u *HistoryBranchUtilImpl) ParseHistoryBranchInfo( branchToken []byte, ) (*persistencespb.HistoryBranch, error) { - return serialization.HistoryBranchFromBlob(&commonpb.DataBlob{Data: branchToken, EncodingType: enumspb.ENCODING_TYPE_PROTO3}) + return u.serializer.HistoryBranchFromBlob(branchToken) } func (u *HistoryBranchUtilImpl) UpdateHistoryBranchInfo( @@ -81,7 +86,7 @@ func (u *HistoryBranchUtilImpl) UpdateHistoryBranchInfo( branchInfo *persistencespb.HistoryBranch, runID string, ) ([]byte, error) { - bi, err := serialization.HistoryBranchFromBlob(&commonpb.DataBlob{Data: branchToken, 
EncodingType: enumspb.ENCODING_TYPE_PROTO3}) + bi, err := u.serializer.HistoryBranchFromBlob(branchToken) if err != nil { return nil, err } @@ -89,13 +94,9 @@ func (u *HistoryBranchUtilImpl) UpdateHistoryBranchInfo( bi.BranchId = branchInfo.BranchId bi.Ancestors = branchInfo.Ancestors - blob, err := serialization.HistoryBranchToBlob(bi) + blob, err := u.serializer.HistoryBranchToBlob(bi) if err != nil { return nil, err } return blob.Data, nil } - -func (u *HistoryBranchUtilImpl) GetHistoryBranchUtil() HistoryBranchUtil { - return u -} diff --git a/common/persistence/history_branch_util_test.go b/common/persistence/history_branch_util_test.go index 23f2701b043..39071e4366a 100644 --- a/common/persistence/history_branch_util_test.go +++ b/common/persistence/history_branch_util_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" persistencespb "go.temporal.io/server/api/persistence/v1" + "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/primitives" ) @@ -35,7 +36,7 @@ func (s *historyBranchUtilSuite) TearDownTest() { } func (s *historyBranchUtilSuite) TestHistoryBranchUtil() { - var historyBranchUtil HistoryBranchUtil = &HistoryBranchUtilImpl{} + var historyBranchUtil HistoryBranchUtil = NewHistoryBranchUtil(serialization.NewSerializer()) treeID0 := primitives.NewUUID().String() branchID0 := primitives.NewUUID().String() diff --git a/common/persistence/history_task_queue_manager.go b/common/persistence/history_task_queue_manager.go index 84f2944714c..214c372441a 100644 --- a/common/persistence/history_task_queue_manager.go +++ b/common/persistence/history_task_queue_manager.go @@ -48,8 +48,9 @@ var ( func NewHistoryTaskQueueManager(queue QueueV2, serializer serialization.Serializer) *HistoryTaskQueueManagerImpl { return &HistoryTaskQueueManagerImpl{ - queue: queue, - serializer: serializer, + queue: queue, + serializer: serializer, + taskSerializer: 
serialization.NewTaskSerializer(serializer), } } @@ -60,7 +61,7 @@ func (m *HistoryTaskQueueManagerImpl) EnqueueTask( if request.Task == nil { return nil, ErrEnqueueTaskRequestTaskIsNil } - blob, err := m.serializer.SerializeTask(request.Task) + blob, err := m.taskSerializer.SerializeTask(request.Task) if err != nil { return nil, fmt.Errorf("%v: %w", ErrMsgSerializeTaskToEnqueue, err) } @@ -156,7 +157,7 @@ func (m *HistoryTaskQueueManagerImpl) ReadTasks(ctx context.Context, request *Re return nil, serialization.NewDeserializationError(enumspb.ENCODING_TYPE_PROTO3, ErrHistoryTaskBlobIsNil) } - task, err := m.serializer.DeserializeTask(request.QueueKey.Category, blob) + task, err := m.taskSerializer.DeserializeTask(request.QueueKey.Category, blob) if err != nil { return nil, fmt.Errorf("%v: %w", ErrMsgDeserializeHistoryTask, err) } diff --git a/common/persistence/persistence-tests/persistence_test_base.go b/common/persistence/persistence-tests/persistence_test_base.go index 9f29638b4a4..9eee000c5b9 100644 --- a/common/persistence/persistence-tests/persistence_test_base.go +++ b/common/persistence/persistence-tests/persistence_test_base.go @@ -191,6 +191,7 @@ func (s *TestBase) Setup(clusterMetadataConfig *cluster.Config) { s.DefaultTestCluster.SetupTestDatabase() cfg := s.DefaultTestCluster.Config() + serializer := serialization.NewSerializer() dataStoreFactory := client.DataStoreFactoryProvider( client.ClusterName(clusterName), resolver.NewNoopResolver(), @@ -199,6 +200,7 @@ func (s *TestBase) Setup(clusterMetadataConfig *cluster.Config) { s.Logger, metrics.NoopMetricsHandler, s.TracerProvider, + serializer, ) factory := client.NewFactory( dataStoreFactory, @@ -206,7 +208,7 @@ func (s *TestBase) Setup(clusterMetadataConfig *cluster.Config) { s.PersistenceRateLimiter, quotas.NoopRequestRateLimiter, quotas.NoopRequestRateLimiter, - serialization.NewSerializer(), + serializer, nil, clusterName, metrics.NoopMetricsHandler, diff --git 
a/common/persistence/serialization/blob.go b/common/persistence/serialization/blob.go deleted file mode 100644 index 9ba14fe2fd9..00000000000 --- a/common/persistence/serialization/blob.go +++ /dev/null @@ -1,108 +0,0 @@ -package serialization - -import ( - commonpb "go.temporal.io/api/common/v1" - enumspb "go.temporal.io/api/enums/v1" - persistencespb "go.temporal.io/server/api/persistence/v1" - "go.temporal.io/server/common" -) - -func HistoryBranchToBlob(info *persistencespb.HistoryBranch) (*commonpb.DataBlob, error) { - return ProtoEncode(info) -} - -func HistoryBranchFromBlob(data *commonpb.DataBlob) (*persistencespb.HistoryBranch, error) { - result := &persistencespb.HistoryBranch{} - return result, Decode(data, result) -} - -func WorkflowExecutionStateToBlob(info *persistencespb.WorkflowExecutionState) (*commonpb.DataBlob, error) { - return ProtoEncode(info) -} - -func WorkflowExecutionStateFromBlob(data *commonpb.DataBlob) (*persistencespb.WorkflowExecutionState, error) { - result := &persistencespb.WorkflowExecutionState{} - if err := Decode(data, result); err != nil { - return nil, err - } - // Initialize the WorkflowExecutionStateDetails for old records. 
- if result.RequestIds == nil { - result.RequestIds = make(map[string]*persistencespb.RequestIDInfo, 1) - } - if result.CreateRequestId != "" && result.RequestIds[result.CreateRequestId] == nil { - result.RequestIds[result.CreateRequestId] = &persistencespb.RequestIDInfo{ - EventType: enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED, - EventId: common.FirstEventID, - } - } - return result, nil -} - -func TransferTaskInfoToBlob(info *persistencespb.TransferTaskInfo) (*commonpb.DataBlob, error) { - return ProtoEncode(info) -} - -func TransferTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.TransferTaskInfo, error) { - result := &persistencespb.TransferTaskInfo{} - return result, Decode(data, result) -} - -func TimerTaskInfoToBlob(info *persistencespb.TimerTaskInfo) (*commonpb.DataBlob, error) { - return ProtoEncode(info) -} - -func TimerTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.TimerTaskInfo, error) { - result := &persistencespb.TimerTaskInfo{} - return result, Decode(data, result) -} - -func ReplicationTaskInfoToBlob(info *persistencespb.ReplicationTaskInfo) (*commonpb.DataBlob, error) { - return ProtoEncode(info) -} - -func ReplicationTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.ReplicationTaskInfo, error) { - result := &persistencespb.ReplicationTaskInfo{} - return result, Decode(data, result) -} - -func VisibilityTaskInfoToBlob(info *persistencespb.VisibilityTaskInfo) (*commonpb.DataBlob, error) { - return ProtoEncode(info) -} - -func VisibilityTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.VisibilityTaskInfo, error) { - result := &persistencespb.VisibilityTaskInfo{} - return result, Decode(data, result) -} - -func ArchivalTaskInfoToBlob(info *persistencespb.ArchivalTaskInfo) (*commonpb.DataBlob, error) { - return ProtoEncode(info) -} - -func ArchivalTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.ArchivalTaskInfo, error) { - result := &persistencespb.ArchivalTaskInfo{} - return result, Decode(data, 
result) -} - -func OutboundTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.OutboundTaskInfo, error) { - result := &persistencespb.OutboundTaskInfo{} - return result, Decode(data, result) -} - -func QueueMetadataToBlob(metadata *persistencespb.QueueMetadata) (*commonpb.DataBlob, error) { - // TODO change ENCODING_TYPE_JSON to ENCODING_TYPE_PROTO3 - return encodeBlob(metadata, enumspb.ENCODING_TYPE_JSON) -} - -func QueueMetadataFromBlob(data *commonpb.DataBlob) (*persistencespb.QueueMetadata, error) { - result := &persistencespb.QueueMetadata{} - return result, Decode(data, result) -} - -func QueueStateToBlob(info *persistencespb.QueueState) (*commonpb.DataBlob, error) { - return ProtoEncode(info) -} - -func QueueStateFromBlob(data *commonpb.DataBlob) (*persistencespb.QueueState, error) { - result := &persistencespb.QueueState{} - return result, Decode(data, result) -} diff --git a/common/persistence/serialization/codec.go b/common/persistence/serialization/codec.go index 16c640ecd59..b34581fb6bc 100644 --- a/common/persistence/serialization/codec.go +++ b/common/persistence/serialization/codec.go @@ -2,6 +2,9 @@ package serialization import ( "errors" + "fmt" + "os" + "strings" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" @@ -9,6 +12,33 @@ import ( "google.golang.org/protobuf/proto" ) +// SerializerDataEncodingEnvVar controls which codec is used for encoding DataBlobs. +// +// Currently supported values (case-insensitive): +// - "json" +// - "proto3" +// +// Decoding always supports all encodings regardless of this setting. +// +// WARNING: This environment variable should only be used for testing, and never set it in production. +const SerializerDataEncodingEnvVar = "TEMPORAL_TEST_DATA_ENCODING" + +// EncodingTypeFromEnv returns an EncodingType based on the environment variable `TEMPORAL_TEST_DATA_ENCODING`. +// It defaults to "ENCODING_TYPE_PROTO3" codec if the environment variable is not set. 
+func EncodingTypeFromEnv() enumspb.EncodingType { + codecType := os.Getenv(SerializerDataEncodingEnvVar) + switch strings.ToLower(codecType) { + case "", "proto3": + return enumspb.ENCODING_TYPE_PROTO3 + case "json": + return enumspb.ENCODING_TYPE_JSON + default: + //nolint:forbidigo // should fail fast and hard if used incorrectly + panic(fmt.Sprintf("unknown codec %q for environment variable %s", codecType, SerializerDataEncodingEnvVar)) + } +} + +// ProtoEncode is kept for backward compatibility. func ProtoEncode(m proto.Message) (*commonpb.DataBlob, error) { return encodeBlob(m, enumspb.ENCODING_TYPE_PROTO3) } diff --git a/common/persistence/serialization/serializer.go b/common/persistence/serialization/serializer.go index 6d99efb45ff..5e3ed8394d3 100644 --- a/common/persistence/serialization/serializer.go +++ b/common/persistence/serialization/serializer.go @@ -7,98 +7,95 @@ import ( commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" historypb "go.temporal.io/api/history/v1" + "go.temporal.io/api/temporalproto" enumsspb "go.temporal.io/server/api/enums/v1" historyspb "go.temporal.io/server/api/history/v1" persistencespb "go.temporal.io/server/api/persistence/v1" replicationspb "go.temporal.io/server/api/replication/v1" - "go.temporal.io/server/service/history/tasks" + "go.temporal.io/server/common" "google.golang.org/protobuf/proto" ) +// DefaultDecoder is here for convenience to skip the need to create a new Serializer when only decoding is needed. +// It does not need an encoding type, as it will use the one defined in the DataBlob. +var DefaultDecoder Decoder = NewSerializerWithEncoding(enumspb.ENCODING_TYPE_UNSPECIFIED) + type ( - // Serializer is used by persistence to serialize/deserialize objects - // It will only be used inside persistence, so that serialize/deserialize is transparent for application - Serializer interface { + // Encoder is used to encode objects to DataBlobs. 
+ Encoder interface { SerializeEvents(batch []*historypb.HistoryEvent) (*commonpb.DataBlob, error) - DeserializeEvents(data *commonpb.DataBlob) ([]*historypb.HistoryEvent, error) - SerializeEvent(event *historypb.HistoryEvent) (*commonpb.DataBlob, error) - DeserializeEvent(data *commonpb.DataBlob) (*historypb.HistoryEvent, error) - DeserializeStrippedEvents(data *commonpb.DataBlob) ([]*historyspb.StrippedHistoryEvent, error) - SerializeClusterMetadata(icm *persistencespb.ClusterMetadata) (*commonpb.DataBlob, error) - DeserializeClusterMetadata(data *commonpb.DataBlob) (*persistencespb.ClusterMetadata, error) - ShardInfoToBlob(info *persistencespb.ShardInfo) (*commonpb.DataBlob, error) - ShardInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.ShardInfo, error) - NamespaceDetailToBlob(info *persistencespb.NamespaceDetail) (*commonpb.DataBlob, error) - NamespaceDetailFromBlob(data *commonpb.DataBlob) (*persistencespb.NamespaceDetail, error) - HistoryTreeInfoToBlob(info *persistencespb.HistoryTreeInfo) (*commonpb.DataBlob, error) - HistoryTreeInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.HistoryTreeInfo, error) - HistoryBranchToBlob(info *persistencespb.HistoryBranch) (*commonpb.DataBlob, error) - HistoryBranchFromBlob(data *commonpb.DataBlob) (*persistencespb.HistoryBranch, error) - WorkflowExecutionInfoToBlob(info *persistencespb.WorkflowExecutionInfo) (*commonpb.DataBlob, error) - WorkflowExecutionInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.WorkflowExecutionInfo, error) - WorkflowExecutionStateToBlob(info *persistencespb.WorkflowExecutionState) (*commonpb.DataBlob, error) - WorkflowExecutionStateFromBlob(data *commonpb.DataBlob) (*persistencespb.WorkflowExecutionState, error) - ActivityInfoToBlob(info *persistencespb.ActivityInfo) (*commonpb.DataBlob, error) - ActivityInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.ActivityInfo, error) - ChildExecutionInfoToBlob(info *persistencespb.ChildExecutionInfo) (*commonpb.DataBlob, error) - 
ChildExecutionInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.ChildExecutionInfo, error) - SignalInfoToBlob(info *persistencespb.SignalInfo) (*commonpb.DataBlob, error) - SignalInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.SignalInfo, error) - RequestCancelInfoToBlob(info *persistencespb.RequestCancelInfo) (*commonpb.DataBlob, error) - RequestCancelInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.RequestCancelInfo, error) - TimerInfoToBlob(info *persistencespb.TimerInfo) (*commonpb.DataBlob, error) - TimerInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.TimerInfo, error) - TaskInfoToBlob(info *persistencespb.AllocatedTaskInfo) (*commonpb.DataBlob, error) - TaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.AllocatedTaskInfo, error) - TaskQueueInfoToBlob(info *persistencespb.TaskQueueInfo) (*commonpb.DataBlob, error) - TaskQueueInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.TaskQueueInfo, error) - TaskQueueUserDataToBlob(info *persistencespb.TaskQueueUserData) (*commonpb.DataBlob, error) - TaskQueueUserDataFromBlob(data *commonpb.DataBlob) (*persistencespb.TaskQueueUserData, error) - ChecksumToBlob(checksum *persistencespb.Checksum) (*commonpb.DataBlob, error) - ChecksumFromBlob(data *commonpb.DataBlob) (*persistencespb.Checksum, error) - QueueMetadataToBlob(metadata *persistencespb.QueueMetadata) (*commonpb.DataBlob, error) - QueueMetadataFromBlob(data *commonpb.DataBlob) (*persistencespb.QueueMetadata, error) - ReplicationTaskToBlob(replicationTask *replicationspb.ReplicationTask) (*commonpb.DataBlob, error) - ReplicationTaskFromBlob(data *commonpb.DataBlob) (*replicationspb.ReplicationTask, error) - // ParseReplicationTask is unique among these methods in that it does not serialize or deserialize a type to or - // from a byte array. Instead, it takes a proto and "parses" it into a more structured type. 
- ParseReplicationTask(replicationTask *persistencespb.ReplicationTaskInfo) (tasks.Task, error) - // ParseReplicationTaskInfo is unique among these methods in that it does not serialize or deserialize a type to or - // from a byte array. Instead, it takes a structured type and "parses" it into proto - ParseReplicationTaskInfo(task tasks.Task) (*persistencespb.ReplicationTaskInfo, error) - - SerializeTask(task tasks.Task) (*commonpb.DataBlob, error) - DeserializeTask(category tasks.Category, blob *commonpb.DataBlob) (tasks.Task, error) - NexusEndpointToBlob(endpoint *persistencespb.NexusEndpoint) (*commonpb.DataBlob, error) - NexusEndpointFromBlob(data *commonpb.DataBlob) (*persistencespb.NexusEndpoint, error) - // ChasmNodeToBlob returns a single encoded blob for the node. ChasmNodeToBlob(node *persistencespb.ChasmNode) (*commonpb.DataBlob, error) - ChasmNodeFromBlob(blob *commonpb.DataBlob) (*persistencespb.ChasmNode, error) - // ChasmNodeToBlobs returns the metadata blob first, followed by the data blob. ChasmNodeToBlobs(node *persistencespb.ChasmNode) (*commonpb.DataBlob, *commonpb.DataBlob, error) + TransferTaskInfoToBlob(info *persistencespb.TransferTaskInfo) (*commonpb.DataBlob, error) + TimerTaskInfoToBlob(info *persistencespb.TimerTaskInfo) (*commonpb.DataBlob, error) + ReplicationTaskInfoToBlob(info *persistencespb.ReplicationTaskInfo) (*commonpb.DataBlob, error) + VisibilityTaskInfoToBlob(info *persistencespb.VisibilityTaskInfo) (*commonpb.DataBlob, error) + ArchivalTaskInfoToBlob(info *persistencespb.ArchivalTaskInfo) (*commonpb.DataBlob, error) + OutboundTaskInfoToBlob(info *persistencespb.OutboundTaskInfo) (*commonpb.DataBlob, error) + QueueStateToBlob(info *persistencespb.QueueState) (*commonpb.DataBlob, error) + } + + // Decoder is used to decode DataBlobs to objects. 
+ Decoder interface { + DeserializeEvents(data *commonpb.DataBlob) ([]*historypb.HistoryEvent, error) + DeserializeEvent(data *commonpb.DataBlob) (*historypb.HistoryEvent, error) + DeserializeStrippedEvents(data *commonpb.DataBlob) ([]*historyspb.StrippedHistoryEvent, error) + DeserializeClusterMetadata(data *commonpb.DataBlob) (*persistencespb.ClusterMetadata, error) + ShardInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.ShardInfo, error) + NamespaceDetailFromBlob(data *commonpb.DataBlob) (*persistencespb.NamespaceDetail, error) + HistoryTreeInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.HistoryTreeInfo, error) + HistoryBranchFromBlob(data []byte) (*persistencespb.HistoryBranch, error) + WorkflowExecutionInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.WorkflowExecutionInfo, error) + WorkflowExecutionStateFromBlob(data *commonpb.DataBlob) (*persistencespb.WorkflowExecutionState, error) + ActivityInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.ActivityInfo, error) + ChildExecutionInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.ChildExecutionInfo, error) + SignalInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.SignalInfo, error) + RequestCancelInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.RequestCancelInfo, error) + TimerInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.TimerInfo, error) + TaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.AllocatedTaskInfo, error) + TaskQueueInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.TaskQueueInfo, error) + TaskQueueUserDataFromBlob(data *commonpb.DataBlob) (*persistencespb.TaskQueueUserData, error) + ChecksumFromBlob(data *commonpb.DataBlob) (*persistencespb.Checksum, error) + QueueMetadataFromBlob(data *commonpb.DataBlob) (*persistencespb.QueueMetadata, error) + ReplicationTaskFromBlob(data *commonpb.DataBlob) (*replicationspb.ReplicationTask, error) + NexusEndpointFromBlob(data *commonpb.DataBlob) (*persistencespb.NexusEndpoint, error) + 
ChasmNodeFromBlob(blob *commonpb.DataBlob) (*persistencespb.ChasmNode, error) ChasmNodeFromBlobs(metadata *commonpb.DataBlob, data *commonpb.DataBlob) (*persistencespb.ChasmNode, error) + TransferTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.TransferTaskInfo, error) + TimerTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.TimerTaskInfo, error) + ReplicationTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.ReplicationTaskInfo, error) + VisibilityTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.VisibilityTaskInfo, error) + ArchivalTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.ArchivalTaskInfo, error) + OutboundTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.OutboundTaskInfo, error) + QueueStateFromBlob(data *commonpb.DataBlob) (*persistencespb.QueueState, error) + } + + // Serializer is used to serialize and deserialize DataBlobs. + Serializer interface { + Encoder + Decoder } // SerializationError is an error type for serialization @@ -120,7 +117,7 @@ type ( } serializerImpl struct { - TaskSerializer + encodingType enumspb.EncodingType } marshaler interface { @@ -128,9 +125,15 @@ type ( } ) -// NewSerializer returns a PayloadSerializer +// NewSerializer returns a Serializer that picks the encoding type based on an environment variable. +// If none is set, it defaults to the Proto3 codec. See EncodingTypeFromEnv for details. func NewSerializer() Serializer { - return &serializerImpl{} + return NewSerializerWithEncoding(EncodingTypeFromEnv()) +} + +// NewSerializerWithEncoding returns a Serializer that uses the provided encoding type. 
+func NewSerializerWithEncoding(encodingType enumspb.EncodingType) Serializer { + return &serializerImpl{encodingType: encodingType} } func (t *serializerImpl) SerializeEvents(events []*historypb.HistoryEvent) (*commonpb.DataBlob, error) { @@ -146,16 +149,9 @@ func (t *serializerImpl) DeserializeEvents(data *commonpb.DataBlob) ([]*historyp } events := &historypb.History{} - var err error - switch data.EncodingType { - case enumspb.ENCODING_TYPE_PROTO3: - // Client API currently specifies encodingType on requests which span multiple of these objects - err = events.Unmarshal(data.Data) - default: - return nil, NewUnknownEncodingTypeError(data.EncodingType.String(), enumspb.ENCODING_TYPE_PROTO3) - } + err := Decode(data, events) if err != nil { - return nil, NewDeserializationError(enumspb.ENCODING_TYPE_PROTO3, err) + return nil, err } return events.Events, nil } @@ -170,7 +166,6 @@ func (t *serializerImpl) DeserializeStrippedEvents(data *commonpb.DataBlob) ([]* events := &historyspb.StrippedHistoryEvents{} var err error - //nolint:exhaustive switch data.EncodingType { case enumspb.ENCODING_TYPE_PROTO3: // Discard unknown fields to improve performance. 
StrippedHistoryEvents is usually deserialized from HistoryEvent @@ -178,11 +173,16 @@ func (t *serializerImpl) DeserializeStrippedEvents(data *commonpb.DataBlob) ([]* err = proto.UnmarshalOptions{ DiscardUnknown: true, }.Unmarshal(data.Data, events) + case enumspb.ENCODING_TYPE_JSON: + err = temporalproto.CustomJSONUnmarshalOptions{ + DiscardUnknown: true, + }.Unmarshal(data.Data, events) default: - return nil, NewUnknownEncodingTypeError(data.EncodingType.String(), enumspb.ENCODING_TYPE_PROTO3) + return nil, NewUnknownEncodingTypeError(data.EncodingType.String(), + enumspb.ENCODING_TYPE_PROTO3, enumspb.ENCODING_TYPE_JSON) } if err != nil { - return nil, NewDeserializationError(enumspb.ENCODING_TYPE_PROTO3, err) + return nil, NewDeserializationError(data.EncodingType, err) } return events.Events, nil } @@ -203,19 +203,11 @@ func (t *serializerImpl) DeserializeEvent(data *commonpb.DataBlob) (*historypb.H } event := &historypb.HistoryEvent{} - var err error - switch data.EncodingType { - case enumspb.ENCODING_TYPE_PROTO3: - // Client API currently specifies encodingType on requests which span multiple of these objects - err = event.Unmarshal(data.Data) - default: - return nil, NewUnknownEncodingTypeError(data.EncodingType.String(), enumspb.ENCODING_TYPE_PROTO3) - } + err := Decode(data, event) if err != nil { - return nil, NewDeserializationError(enumspb.ENCODING_TYPE_PROTO3, err) + return nil, err } - - return event, err + return event, nil } func (t *serializerImpl) SerializeClusterMetadata(cm *persistencespb.ClusterMetadata) (*commonpb.DataBlob, error) { @@ -234,41 +226,22 @@ func (t *serializerImpl) DeserializeClusterMetadata(data *commonpb.DataBlob) (*p } cm := &persistencespb.ClusterMetadata{} - var err error - switch data.EncodingType { - case enumspb.ENCODING_TYPE_PROTO3: - // Thrift == Proto for this object so that we can maintain test behavior until thrift is gone - // Client API currently specifies encodingType on requests which span multiple of these 
objects - err = cm.Unmarshal(data.Data) - default: - return nil, NewUnknownEncodingTypeError(data.EncodingType.String(), enumspb.ENCODING_TYPE_PROTO3) - } + err := Decode(data, cm) if err != nil { - return nil, NewSerializationError(enumspb.ENCODING_TYPE_PROTO3, err) + return nil, err } - - return cm, err + return cm, nil } -func (t *serializerImpl) serialize(p marshaler) (*commonpb.DataBlob, error) { +func (t *serializerImpl) serialize(p proto.Message) (*commonpb.DataBlob, error) { if p == nil { return nil, nil } - - data, err := p.Marshal() + blob, err := encodeBlob(p, t.encodingType) if err != nil { - return nil, NewSerializationError(enumspb.ENCODING_TYPE_PROTO3, err) - } - - // Shouldn't happen, but keeping - if data == nil { - return nil, nil + return nil, NewSerializationError(t.encodingType, err) } - - return &commonpb.DataBlob{ - Data: data, - EncodingType: enumspb.ENCODING_TYPE_PROTO3, - }, nil + return blob, nil } // NewUnknownEncodingTypeError returns a new instance of encoding type error @@ -345,7 +318,7 @@ func (e *DeserializationError) Unwrap() error { func (e *DeserializationError) IsTerminalTaskError() bool { return true } func (t *serializerImpl) ShardInfoToBlob(info *persistencespb.ShardInfo) (*commonpb.DataBlob, error) { - return ProtoEncode(info) + return encodeBlob(info, t.encodingType) } func (t *serializerImpl) ShardInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.ShardInfo, error) { @@ -378,7 +351,7 @@ func (t *serializerImpl) ShardInfoFromBlob(data *commonpb.DataBlob) (*persistenc } func (t *serializerImpl) NamespaceDetailToBlob(info *persistencespb.NamespaceDetail) (*commonpb.DataBlob, error) { - return ProtoEncode(info) + return encodeBlob(info, t.encodingType) } func (t *serializerImpl) NamespaceDetailFromBlob(data *commonpb.DataBlob) (*persistencespb.NamespaceDetail, error) { @@ -387,7 +360,7 @@ func (t *serializerImpl) NamespaceDetailFromBlob(data *commonpb.DataBlob) (*pers } func (t *serializerImpl) HistoryTreeInfoToBlob(info 
*persistencespb.HistoryTreeInfo) (*commonpb.DataBlob, error) { - return ProtoEncode(info) + return encodeBlob(info, t.encodingType) } func (t *serializerImpl) HistoryTreeInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.HistoryTreeInfo, error) { @@ -396,16 +369,17 @@ func (t *serializerImpl) HistoryTreeInfoFromBlob(data *commonpb.DataBlob) (*pers } func (t *serializerImpl) HistoryBranchToBlob(info *persistencespb.HistoryBranch) (*commonpb.DataBlob, error) { - return ProtoEncode(info) + return encodeBlob(info, t.encodingType) } -func (t *serializerImpl) HistoryBranchFromBlob(data *commonpb.DataBlob) (*persistencespb.HistoryBranch, error) { +// NOTE: HistoryBranch does not have an encoding type; so we use the serializer's encoding type. +func (t *serializerImpl) HistoryBranchFromBlob(data []byte) (*persistencespb.HistoryBranch, error) { result := &persistencespb.HistoryBranch{} - return result, Decode(data, result) + return result, Decode(&commonpb.DataBlob{Data: data, EncodingType: t.encodingType}, result) } func (t *serializerImpl) WorkflowExecutionInfoToBlob(info *persistencespb.WorkflowExecutionInfo) (*commonpb.DataBlob, error) { - return ProtoEncode(info) + return encodeBlob(info, t.encodingType) } func (t *serializerImpl) WorkflowExecutionInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.WorkflowExecutionInfo, error) { @@ -422,15 +396,29 @@ func (t *serializerImpl) WorkflowExecutionInfoFromBlob(data *commonpb.DataBlob) } func (t *serializerImpl) WorkflowExecutionStateToBlob(info *persistencespb.WorkflowExecutionState) (*commonpb.DataBlob, error) { - return ProtoEncode(info) + return encodeBlob(info, t.encodingType) } func (t *serializerImpl) WorkflowExecutionStateFromBlob(data *commonpb.DataBlob) (*persistencespb.WorkflowExecutionState, error) { - return WorkflowExecutionStateFromBlob(data) + result := &persistencespb.WorkflowExecutionState{} + if err := Decode(data, result); err != nil { + return nil, err + } + // Initialize the 
WorkflowExecutionStateDetails for old records. + if result.RequestIds == nil { + result.RequestIds = make(map[string]*persistencespb.RequestIDInfo, 1) + } + if result.CreateRequestId != "" && result.RequestIds[result.CreateRequestId] == nil { + result.RequestIds[result.CreateRequestId] = &persistencespb.RequestIDInfo{ + EventType: enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED, + EventId: common.FirstEventID, + } + } + return result, nil } func (t *serializerImpl) ActivityInfoToBlob(info *persistencespb.ActivityInfo) (*commonpb.DataBlob, error) { - return ProtoEncode(info) + return encodeBlob(info, t.encodingType) } func (t *serializerImpl) ActivityInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.ActivityInfo, error) { @@ -439,7 +427,7 @@ func (t *serializerImpl) ActivityInfoFromBlob(data *commonpb.DataBlob) (*persist } func (t *serializerImpl) ChildExecutionInfoToBlob(info *persistencespb.ChildExecutionInfo) (*commonpb.DataBlob, error) { - return ProtoEncode(info) + return encodeBlob(info, t.encodingType) } func (t *serializerImpl) ChildExecutionInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.ChildExecutionInfo, error) { @@ -448,7 +436,7 @@ func (t *serializerImpl) ChildExecutionInfoFromBlob(data *commonpb.DataBlob) (*p } func (t *serializerImpl) SignalInfoToBlob(info *persistencespb.SignalInfo) (*commonpb.DataBlob, error) { - return ProtoEncode(info) + return encodeBlob(info, t.encodingType) } func (t *serializerImpl) SignalInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.SignalInfo, error) { @@ -457,7 +445,7 @@ func (t *serializerImpl) SignalInfoFromBlob(data *commonpb.DataBlob) (*persisten } func (t *serializerImpl) RequestCancelInfoToBlob(info *persistencespb.RequestCancelInfo) (*commonpb.DataBlob, error) { - return ProtoEncode(info) + return encodeBlob(info, t.encodingType) } func (t *serializerImpl) RequestCancelInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.RequestCancelInfo, error) { @@ -466,7 +454,7 @@ func (t *serializerImpl) 
RequestCancelInfoFromBlob(data *commonpb.DataBlob) (*pe } func (t *serializerImpl) TimerInfoToBlob(info *persistencespb.TimerInfo) (*commonpb.DataBlob, error) { - return ProtoEncode(info) + return encodeBlob(info, t.encodingType) } func (t *serializerImpl) TimerInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.TimerInfo, error) { @@ -475,7 +463,7 @@ func (t *serializerImpl) TimerInfoFromBlob(data *commonpb.DataBlob) (*persistenc } func (t *serializerImpl) TaskInfoToBlob(info *persistencespb.AllocatedTaskInfo) (*commonpb.DataBlob, error) { - return ProtoEncode(info) + return encodeBlob(info, t.encodingType) } func (t *serializerImpl) TaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.AllocatedTaskInfo, error) { @@ -484,7 +472,7 @@ func (t *serializerImpl) TaskInfoFromBlob(data *commonpb.DataBlob) (*persistence } func (t *serializerImpl) TaskQueueInfoToBlob(info *persistencespb.TaskQueueInfo) (*commonpb.DataBlob, error) { - return ProtoEncode(info) + return encodeBlob(info, t.encodingType) } func (t *serializerImpl) TaskQueueInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.TaskQueueInfo, error) { @@ -493,7 +481,7 @@ func (t *serializerImpl) TaskQueueInfoFromBlob(data *commonpb.DataBlob) (*persis } func (t *serializerImpl) TaskQueueUserDataToBlob(data *persistencespb.TaskQueueUserData) (*commonpb.DataBlob, error) { - return ProtoEncode(data) + return encodeBlob(data, t.encodingType) } func (t *serializerImpl) TaskQueueUserDataFromBlob(data *commonpb.DataBlob) (*persistencespb.TaskQueueUserData, error) { @@ -506,7 +494,7 @@ func (t *serializerImpl) ChecksumToBlob(checksum *persistencespb.Checksum) (*com if checksum == nil { checksum = &persistencespb.Checksum{} } - return ProtoEncode(checksum) + return encodeBlob(checksum, t.encodingType) } func (t *serializerImpl) ChecksumFromBlob(data *commonpb.DataBlob) (*persistencespb.Checksum, error) { @@ -520,7 +508,8 @@ func (t *serializerImpl) ChecksumFromBlob(data *commonpb.DataBlob) (*persistence } func (t 
*serializerImpl) QueueMetadataToBlob(metadata *persistencespb.QueueMetadata) (*commonpb.DataBlob, error) { - return ProtoEncode(metadata) + // TODO change ENCODING_TYPE_JSON to ENCODING_TYPE_PROTO3 + return encodeBlob(metadata, enumspb.ENCODING_TYPE_JSON) } func (t *serializerImpl) QueueMetadataFromBlob(data *commonpb.DataBlob) (*persistencespb.QueueMetadata, error) { @@ -529,7 +518,7 @@ func (t *serializerImpl) QueueMetadataFromBlob(data *commonpb.DataBlob) (*persis } func (t *serializerImpl) ReplicationTaskToBlob(replicationTask *replicationspb.ReplicationTask) (*commonpb.DataBlob, error) { - return ProtoEncode(replicationTask) + return encodeBlob(replicationTask, t.encodingType) } func (t *serializerImpl) ReplicationTaskFromBlob(data *commonpb.DataBlob) (*replicationspb.ReplicationTask, error) { @@ -538,7 +527,7 @@ func (t *serializerImpl) ReplicationTaskFromBlob(data *commonpb.DataBlob) (*repl } func (t *serializerImpl) NexusEndpointToBlob(endpoint *persistencespb.NexusEndpoint) (*commonpb.DataBlob, error) { - return ProtoEncode(endpoint) + return encodeBlob(endpoint, t.encodingType) } func (t *serializerImpl) NexusEndpointFromBlob(data *commonpb.DataBlob) (*persistencespb.NexusEndpoint, error) { @@ -547,7 +536,7 @@ func (t *serializerImpl) NexusEndpointFromBlob(data *commonpb.DataBlob) (*persis } func (t *serializerImpl) ChasmNodeToBlobs(node *persistencespb.ChasmNode) (metadata *commonpb.DataBlob, nodedata *commonpb.DataBlob, retErr error) { - metadata, retErr = ProtoEncode(node.Metadata) + metadata, retErr = encodeBlob(node.Metadata, t.encodingType) if retErr != nil { return nil, nil, retErr } @@ -563,10 +552,73 @@ func (t *serializerImpl) ChasmNodeFromBlobs(metadata *commonpb.DataBlob, data *c } func (t *serializerImpl) ChasmNodeToBlob(node *persistencespb.ChasmNode) (*commonpb.DataBlob, error) { - return ProtoEncode(node) + return encodeBlob(node, t.encodingType) } func (t *serializerImpl) ChasmNodeFromBlob(blob *commonpb.DataBlob) 
(*persistencespb.ChasmNode, error) { result := &persistencespb.ChasmNode{} return result, Decode(blob, result) } + +func (t *serializerImpl) TransferTaskInfoToBlob(info *persistencespb.TransferTaskInfo) (*commonpb.DataBlob, error) { + return encodeBlob(info, t.encodingType) +} + +func (t *serializerImpl) TransferTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.TransferTaskInfo, error) { + result := &persistencespb.TransferTaskInfo{} + return result, Decode(data, result) +} + +func (t *serializerImpl) TimerTaskInfoToBlob(info *persistencespb.TimerTaskInfo) (*commonpb.DataBlob, error) { + return encodeBlob(info, t.encodingType) +} + +func (t *serializerImpl) TimerTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.TimerTaskInfo, error) { + result := &persistencespb.TimerTaskInfo{} + return result, Decode(data, result) +} + +func (t *serializerImpl) ReplicationTaskInfoToBlob(info *persistencespb.ReplicationTaskInfo) (*commonpb.DataBlob, error) { + return encodeBlob(info, t.encodingType) +} + +func (t *serializerImpl) ReplicationTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.ReplicationTaskInfo, error) { + result := &persistencespb.ReplicationTaskInfo{} + return result, Decode(data, result) +} + +func (t *serializerImpl) VisibilityTaskInfoToBlob(info *persistencespb.VisibilityTaskInfo) (*commonpb.DataBlob, error) { + return encodeBlob(info, t.encodingType) +} + +func (t *serializerImpl) VisibilityTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.VisibilityTaskInfo, error) { + result := &persistencespb.VisibilityTaskInfo{} + return result, Decode(data, result) +} + +func (t *serializerImpl) ArchivalTaskInfoToBlob(info *persistencespb.ArchivalTaskInfo) (*commonpb.DataBlob, error) { + return encodeBlob(info, t.encodingType) +} + +func (t *serializerImpl) ArchivalTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.ArchivalTaskInfo, error) { + result := &persistencespb.ArchivalTaskInfo{} + return result, Decode(data, result) +} + 
+func (t *serializerImpl) OutboundTaskInfoToBlob(info *persistencespb.OutboundTaskInfo) (*commonpb.DataBlob, error) { + return encodeBlob(info, t.encodingType) +} + +func (t *serializerImpl) OutboundTaskInfoFromBlob(data *commonpb.DataBlob) (*persistencespb.OutboundTaskInfo, error) { + result := &persistencespb.OutboundTaskInfo{} + return result, Decode(data, result) +} + +func (t *serializerImpl) QueueStateToBlob(info *persistencespb.QueueState) (*commonpb.DataBlob, error) { + return encodeBlob(info, t.encodingType) +} + +func (t *serializerImpl) QueueStateFromBlob(data *commonpb.DataBlob) (*persistencespb.QueueState, error) { + result := &persistencespb.QueueState{} + return result, Decode(data, result) +} diff --git a/common/persistence/serialization/serializer_test.go b/common/persistence/serialization/serializer_test.go index 82b6055f5dd..44eb5b5694d 100644 --- a/common/persistence/serialization/serializer_test.go +++ b/common/persistence/serialization/serializer_test.go @@ -194,7 +194,7 @@ func (s *temporalSerializerSuite) TestDeserializeStrippedEvents() { // 3. 
Unknown encoding type s.Run("UnknownEncodingType", func() { _, err := s.serializer.DeserializeStrippedEvents(&commonpb.DataBlob{ - EncodingType: enumspb.ENCODING_TYPE_JSON, // Not handled by our switch + EncodingType: enumspb.ENCODING_TYPE_UNSPECIFIED, Data: []byte("irrelevant-data"), }) s.Error(err) diff --git a/common/persistence/serialization/task_serializer.go b/common/persistence/serialization/task_serializer.go index fc19a5f459b..643f33ea6c6 100644 --- a/common/persistence/serialization/task_serializer.go +++ b/common/persistence/serialization/task_serializer.go @@ -14,16 +14,28 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) +var _ TaskSerializer = (*taskSerializer)(nil) + type ( - TaskSerializer struct { + TaskSerializer interface { + ReplicationTaskSerializer + SerializeTask(task tasks.Task) (*commonpb.DataBlob, error) + DeserializeTask(category tasks.Category, blob *commonpb.DataBlob) (tasks.Task, error) + } + ReplicationTaskSerializer interface { + SeralizeReplicationTask(task tasks.Task) (*persistencespb.ReplicationTaskInfo, error) + DeserializeReplicationTask(replicationTask *persistencespb.ReplicationTaskInfo) (tasks.Task, error) + } + taskSerializer struct { + serializer Serializer } ) -func NewTaskSerializer() *TaskSerializer { - return &TaskSerializer{} +func NewTaskSerializer(serializer Serializer) TaskSerializer { + return &taskSerializer{serializer: serializer} } -func (s *TaskSerializer) SerializeTask( +func (s *taskSerializer) SerializeTask( task tasks.Task, ) (*commonpb.DataBlob, error) { category := task.GetCategory() @@ -45,21 +57,21 @@ func (s *TaskSerializer) SerializeTask( } } -func (s *TaskSerializer) DeserializeTask( +func (s *taskSerializer) DeserializeTask( category tasks.Category, blob *commonpb.DataBlob, ) (tasks.Task, error) { switch category.ID() { case tasks.CategoryIDTransfer: - return s.deserializeTransferTasks(blob) + return s.deserializeTransferTask(blob) case tasks.CategoryIDTimer: - return 
s.deserializeTimerTasks(blob) + return s.deserializeTimerTask(blob) case tasks.CategoryIDVisibility: - return s.deserializeVisibilityTasks(blob) + return s.deserializeVisibilityTask(blob) case tasks.CategoryIDReplication: - return s.deserializeReplicationTasks(blob) + return s.deserializeReplicationTask(blob) case tasks.CategoryIDArchival: - return s.deserializeArchivalTasks(blob) + return s.deserializeArchivalTask(blob) case tasks.CategoryIDOutbound: return s.deserializeOutboundTask(blob) default: @@ -67,7 +79,7 @@ func (s *TaskSerializer) DeserializeTask( } } -func (s *TaskSerializer) serializeTransferTask( +func (s *taskSerializer) serializeTransferTask( task tasks.Task, ) (*commonpb.DataBlob, error) { var transferTask *persistencespb.TransferTaskInfo @@ -93,11 +105,10 @@ func (s *TaskSerializer) serializeTransferTask( default: return nil, serviceerror.NewInternalf("Unknown transfer task type: %v", task) } - - return TransferTaskInfoToBlob(transferTask) + return s.serializer.TransferTaskInfoToBlob(transferTask) } -func (s *TaskSerializer) transferChasmTaskToProto(task *tasks.ChasmTask) *persistencespb.TransferTaskInfo { +func (s *taskSerializer) transferChasmTaskToProto(task *tasks.ChasmTask) *persistencespb.TransferTaskInfo { return &persistencespb.TransferTaskInfo{ NamespaceId: task.WorkflowKey.NamespaceID, WorkflowId: task.WorkflowKey.WorkflowID, @@ -111,14 +122,13 @@ func (s *TaskSerializer) transferChasmTaskToProto(task *tasks.ChasmTask) *persis } } -func (s *TaskSerializer) deserializeTransferTasks( +func (s *taskSerializer) deserializeTransferTask( blob *commonpb.DataBlob, ) (tasks.Task, error) { - transferTask, err := TransferTaskInfoFromBlob(blob) + transferTask, err := s.serializer.TransferTaskInfoFromBlob(blob) if err != nil { return nil, err } - var task tasks.Task switch transferTask.TaskType { case enumsspb.TASK_TYPE_TRANSFER_WORKFLOW_TASK: @@ -145,7 +155,7 @@ func (s *TaskSerializer) deserializeTransferTasks( return task, nil } -func (s 
*TaskSerializer) transferChasmTaskFromProto(task *persistencespb.TransferTaskInfo) tasks.Task { +func (s *taskSerializer) transferChasmTaskFromProto(task *persistencespb.TransferTaskInfo) tasks.Task { return &tasks.ChasmTask{ WorkflowKey: definition.NewWorkflowKey( task.NamespaceId, @@ -159,7 +169,7 @@ func (s *TaskSerializer) transferChasmTaskFromProto(task *persistencespb.Transfe } } -func (s *TaskSerializer) serializeTimerTask( +func (s *taskSerializer) serializeTimerTask( task tasks.Task, ) (*commonpb.DataBlob, error) { var timerTask *persistencespb.TimerTaskInfo @@ -189,10 +199,10 @@ func (s *TaskSerializer) serializeTimerTask( default: return nil, serviceerror.NewInternalf("Unknown timer task type: %v", task) } - return TimerTaskInfoToBlob(timerTask) + return s.serializer.TimerTaskInfoToBlob(timerTask) } -func (s *TaskSerializer) timerChasmPureTaskToProto(task *tasks.ChasmTaskPure) *persistencespb.TimerTaskInfo { +func (s *taskSerializer) timerChasmPureTaskToProto(task *tasks.ChasmTaskPure) *persistencespb.TimerTaskInfo { return &persistencespb.TimerTaskInfo{ NamespaceId: task.NamespaceID, WorkflowId: task.WorkflowID, @@ -203,7 +213,7 @@ func (s *TaskSerializer) timerChasmPureTaskToProto(task *tasks.ChasmTaskPure) *p } } -func (s *TaskSerializer) timerChasmTaskToProto(task *tasks.ChasmTask) *persistencespb.TimerTaskInfo { +func (s *taskSerializer) timerChasmTaskToProto(task *tasks.ChasmTask) *persistencespb.TimerTaskInfo { return &persistencespb.TimerTaskInfo{ NamespaceId: task.NamespaceID, WorkflowId: task.WorkflowID, @@ -217,14 +227,13 @@ func (s *TaskSerializer) timerChasmTaskToProto(task *tasks.ChasmTask) *persisten } } -func (s *TaskSerializer) deserializeTimerTasks( +func (s *taskSerializer) deserializeTimerTask( blob *commonpb.DataBlob, ) (tasks.Task, error) { - timerTask, err := TimerTaskInfoFromBlob(blob) + timerTask, err := s.serializer.TimerTaskInfoFromBlob(blob) if err != nil { return nil, err } - var timer tasks.Task switch timerTask.TaskType { 
case enumsspb.TASK_TYPE_WORKFLOW_TASK_TIMEOUT: @@ -255,7 +264,7 @@ func (s *TaskSerializer) deserializeTimerTasks( return timer, nil } -func (s *TaskSerializer) timerChasmTaskFromProto(info *persistencespb.TimerTaskInfo) tasks.Task { +func (s *taskSerializer) timerChasmTaskFromProto(info *persistencespb.TimerTaskInfo) tasks.Task { return &tasks.ChasmTask{ WorkflowKey: definition.NewWorkflowKey( info.NamespaceId, @@ -269,7 +278,7 @@ func (s *TaskSerializer) timerChasmTaskFromProto(info *persistencespb.TimerTaskI } } -func (s *TaskSerializer) timerChasmPureTaskFromProto(info *persistencespb.TimerTaskInfo) tasks.Task { +func (s *taskSerializer) timerChasmPureTaskFromProto(info *persistencespb.TimerTaskInfo) tasks.Task { return &tasks.ChasmTaskPure{ WorkflowKey: definition.NewWorkflowKey( info.NamespaceId, @@ -282,7 +291,7 @@ func (s *TaskSerializer) timerChasmPureTaskFromProto(info *persistencespb.TimerT } } -func (s *TaskSerializer) serializeVisibilityTask( +func (s *taskSerializer) serializeVisibilityTask( task tasks.Task, ) (*commonpb.DataBlob, error) { var visibilityTask *persistencespb.VisibilityTaskInfo @@ -300,18 +309,16 @@ func (s *TaskSerializer) serializeVisibilityTask( default: return nil, serviceerror.NewInternalf("Unknown visibility task type: %v", task) } - - return VisibilityTaskInfoToBlob(visibilityTask) + return s.serializer.VisibilityTaskInfoToBlob(visibilityTask) } -func (s *TaskSerializer) deserializeVisibilityTasks( +func (s *taskSerializer) deserializeVisibilityTask( blob *commonpb.DataBlob, ) (tasks.Task, error) { - visibilityTask, err := VisibilityTaskInfoFromBlob(blob) + visibilityTask, err := s.serializer.VisibilityTaskInfoFromBlob(blob) if err != nil { return nil, err } - var visibility tasks.Task switch visibilityTask.TaskType { case enumsspb.TASK_TYPE_VISIBILITY_START_EXECUTION: @@ -330,28 +337,28 @@ func (s *TaskSerializer) deserializeVisibilityTasks( return visibility, nil } -func (s *TaskSerializer) serializeReplicationTask( +func (s 
*taskSerializer) serializeReplicationTask( task tasks.Task, ) (*commonpb.DataBlob, error) { - replicationTask, err := s.ParseReplicationTaskInfo(task) + replicationTask, err := s.SeralizeReplicationTask(task) if err != nil { return nil, err } - return ReplicationTaskInfoToBlob(replicationTask) + return s.serializer.ReplicationTaskInfoToBlob(replicationTask) } -func (s *TaskSerializer) deserializeReplicationTasks( +func (s *taskSerializer) deserializeReplicationTask( blob *commonpb.DataBlob, ) (tasks.Task, error) { - replicationTask, err := ReplicationTaskInfoFromBlob(blob) + replicationTask, err := s.serializer.ReplicationTaskInfoFromBlob(blob) if err != nil { return nil, err } - return s.ParseReplicationTask(replicationTask) + return s.DeserializeReplicationTask(replicationTask) } -func (s *TaskSerializer) ParseReplicationTask(replicationTask *persistencespb.ReplicationTaskInfo) (tasks.Task, error) { +func (s *taskSerializer) DeserializeReplicationTask(replicationTask *persistencespb.ReplicationTaskInfo) (tasks.Task, error) { switch replicationTask.TaskType { case enumsspb.TASK_TYPE_REPLICATION_SYNC_ACTIVITY: return s.replicationActivityTaskFromProto(replicationTask), nil @@ -368,7 +375,7 @@ func (s *TaskSerializer) ParseReplicationTask(replicationTask *persistencespb.Re } } -func (s *TaskSerializer) ParseReplicationTaskInfo(task tasks.Task) (*persistencespb.ReplicationTaskInfo, error) { +func (s *taskSerializer) SeralizeReplicationTask(task tasks.Task) (*persistencespb.ReplicationTaskInfo, error) { switch task := task.(type) { case *tasks.SyncActivityTask: return s.replicationActivityTaskToProto(task), nil @@ -385,7 +392,7 @@ func (s *TaskSerializer) ParseReplicationTaskInfo(task tasks.Task) (*persistence } } -func (s *TaskSerializer) serializeArchivalTask( +func (s *taskSerializer) serializeArchivalTask( task tasks.Task, ) (*commonpb.DataBlob, error) { var archivalTaskInfo *persistencespb.ArchivalTaskInfo @@ -397,13 +404,13 @@ func (s *TaskSerializer) 
serializeArchivalTask( "Unknown archival task type while serializing: %v", task) } - return ArchivalTaskInfoToBlob(archivalTaskInfo) + return s.serializer.ArchivalTaskInfoToBlob(archivalTaskInfo) } -func (s *TaskSerializer) deserializeArchivalTasks( +func (s *taskSerializer) deserializeArchivalTask( blob *commonpb.DataBlob, ) (tasks.Task, error) { - archivalTask, err := ArchivalTaskInfoFromBlob(blob) + archivalTask, err := s.serializer.ArchivalTaskInfoFromBlob(blob) if err != nil { return nil, err } @@ -417,7 +424,7 @@ func (s *TaskSerializer) deserializeArchivalTasks( return task, nil } -func (s *TaskSerializer) transferActivityTaskToProto( +func (s *taskSerializer) transferActivityTaskToProto( activityTask *tasks.ActivityTask, ) *persistencespb.TransferTaskInfo { return &persistencespb.TransferTaskInfo{ @@ -438,7 +445,7 @@ func (s *TaskSerializer) transferActivityTaskToProto( } } -func (s *TaskSerializer) transferActivityTaskFromProto( +func (s *taskSerializer) transferActivityTaskFromProto( activityTask *persistencespb.TransferTaskInfo, ) *tasks.ActivityTask { return &tasks.ActivityTask{ @@ -456,7 +463,7 @@ func (s *TaskSerializer) transferActivityTaskFromProto( } } -func (s *TaskSerializer) transferWorkflowTaskToProto( +func (s *taskSerializer) transferWorkflowTaskToProto( workflowTask *tasks.WorkflowTask, ) *persistencespb.TransferTaskInfo { return &persistencespb.TransferTaskInfo{ @@ -476,7 +483,7 @@ func (s *TaskSerializer) transferWorkflowTaskToProto( } } -func (s *TaskSerializer) transferWorkflowTaskFromProto( +func (s *taskSerializer) transferWorkflowTaskFromProto( workflowTask *persistencespb.TransferTaskInfo, ) *tasks.WorkflowTask { return &tasks.WorkflowTask{ @@ -493,7 +500,7 @@ func (s *TaskSerializer) transferWorkflowTaskFromProto( } } -func (s *TaskSerializer) transferRequestCancelTaskToProto( +func (s *taskSerializer) transferRequestCancelTaskToProto( requestCancelTask *tasks.CancelExecutionTask, ) *persistencespb.TransferTaskInfo { return 
&persistencespb.TransferTaskInfo{ @@ -513,7 +520,7 @@ func (s *TaskSerializer) transferRequestCancelTaskToProto( } } -func (s *TaskSerializer) transferRequestCancelTaskFromProto( +func (s *taskSerializer) transferRequestCancelTaskFromProto( requestCancelTask *persistencespb.TransferTaskInfo, ) *tasks.CancelExecutionTask { return &tasks.CancelExecutionTask{ @@ -533,7 +540,7 @@ func (s *TaskSerializer) transferRequestCancelTaskFromProto( } } -func (s *TaskSerializer) transferSignalTaskToProto( +func (s *taskSerializer) transferSignalTaskToProto( signalTask *tasks.SignalExecutionTask, ) *persistencespb.TransferTaskInfo { return &persistencespb.TransferTaskInfo{ @@ -553,7 +560,7 @@ func (s *TaskSerializer) transferSignalTaskToProto( } } -func (s *TaskSerializer) transferSignalTaskFromProto( +func (s *taskSerializer) transferSignalTaskFromProto( signalTask *persistencespb.TransferTaskInfo, ) *tasks.SignalExecutionTask { return &tasks.SignalExecutionTask{ @@ -573,7 +580,7 @@ func (s *TaskSerializer) transferSignalTaskFromProto( } } -func (s *TaskSerializer) transferChildWorkflowTaskToProto( +func (s *taskSerializer) transferChildWorkflowTaskToProto( childWorkflowTask *tasks.StartChildExecutionTask, ) *persistencespb.TransferTaskInfo { return &persistencespb.TransferTaskInfo{ @@ -593,7 +600,7 @@ func (s *TaskSerializer) transferChildWorkflowTaskToProto( } } -func (s *TaskSerializer) transferChildWorkflowTaskFromProto( +func (s *taskSerializer) transferChildWorkflowTaskFromProto( signalTask *persistencespb.TransferTaskInfo, ) *tasks.StartChildExecutionTask { return &tasks.StartChildExecutionTask{ @@ -611,7 +618,7 @@ func (s *TaskSerializer) transferChildWorkflowTaskFromProto( } } -func (s *TaskSerializer) transferCloseTaskToProto( +func (s *taskSerializer) transferCloseTaskToProto( closeTask *tasks.CloseExecutionTask, ) *persistencespb.TransferTaskInfo { return &persistencespb.TransferTaskInfo{ @@ -639,7 +646,7 @@ func (s *TaskSerializer) transferCloseTaskToProto( } } 
-func (s *TaskSerializer) transferCloseTaskFromProto( +func (s *taskSerializer) transferCloseTaskFromProto( closeTask *persistencespb.TransferTaskInfo, ) *tasks.CloseExecutionTask { return &tasks.CloseExecutionTask{ @@ -657,7 +664,7 @@ func (s *TaskSerializer) transferCloseTaskFromProto( } } -func (s *TaskSerializer) transferResetTaskToProto( +func (s *taskSerializer) transferResetTaskToProto( resetTask *tasks.ResetWorkflowTask, ) *persistencespb.TransferTaskInfo { return &persistencespb.TransferTaskInfo{ @@ -677,7 +684,7 @@ func (s *TaskSerializer) transferResetTaskToProto( } } -func (s *TaskSerializer) transferResetTaskFromProto( +func (s *taskSerializer) transferResetTaskFromProto( resetTask *persistencespb.TransferTaskInfo, ) *tasks.ResetWorkflowTask { return &tasks.ResetWorkflowTask{ @@ -692,7 +699,7 @@ func (s *TaskSerializer) transferResetTaskFromProto( } } -func (s *TaskSerializer) transferDeleteExecutionTaskToProto( +func (s *taskSerializer) transferDeleteExecutionTaskToProto( deleteExecutionTask *tasks.DeleteExecutionTask, ) *persistencespb.TransferTaskInfo { return &persistencespb.TransferTaskInfo{ @@ -705,7 +712,7 @@ func (s *TaskSerializer) transferDeleteExecutionTaskToProto( } } -func (s *TaskSerializer) transferDeleteExecutionTaskFromProto( +func (s *taskSerializer) transferDeleteExecutionTaskFromProto( deleteExecutionTask *persistencespb.TransferTaskInfo, ) *tasks.DeleteExecutionTask { return &tasks.DeleteExecutionTask{ @@ -721,7 +728,7 @@ func (s *TaskSerializer) transferDeleteExecutionTaskFromProto( } } -func (s *TaskSerializer) timerWorkflowTaskToProto( +func (s *taskSerializer) timerWorkflowTaskToProto( workflowTimer *tasks.WorkflowTaskTimeoutTask, ) *persistencespb.TimerTaskInfo { return &persistencespb.TimerTaskInfo{ @@ -739,7 +746,7 @@ func (s *TaskSerializer) timerWorkflowTaskToProto( } } -func (s *TaskSerializer) timerWorkflowTaskFromProto( +func (s *taskSerializer) timerWorkflowTaskFromProto( workflowTimer *persistencespb.TimerTaskInfo, ) 
*tasks.WorkflowTaskTimeoutTask { return &tasks.WorkflowTaskTimeoutTask{ @@ -757,7 +764,7 @@ func (s *TaskSerializer) timerWorkflowTaskFromProto( } } -func (s *TaskSerializer) timerWorkflowDelayTaskToProto( +func (s *taskSerializer) timerWorkflowDelayTaskToProto( workflowDelayTimer *tasks.WorkflowBackoffTimerTask, ) *persistencespb.TimerTaskInfo { return &persistencespb.TimerTaskInfo{ @@ -775,7 +782,7 @@ func (s *TaskSerializer) timerWorkflowDelayTaskToProto( } } -func (s *TaskSerializer) timerWorkflowDelayTaskFromProto( +func (s *taskSerializer) timerWorkflowDelayTaskFromProto( workflowDelayTimer *persistencespb.TimerTaskInfo, ) *tasks.WorkflowBackoffTimerTask { return &tasks.WorkflowBackoffTimerTask{ @@ -791,7 +798,7 @@ func (s *TaskSerializer) timerWorkflowDelayTaskFromProto( } } -func (s *TaskSerializer) timerActivityTaskToProto( +func (s *taskSerializer) timerActivityTaskToProto( activityTimer *tasks.ActivityTimeoutTask, ) *persistencespb.TimerTaskInfo { return &persistencespb.TimerTaskInfo{ @@ -809,7 +816,7 @@ func (s *TaskSerializer) timerActivityTaskToProto( } } -func (s *TaskSerializer) timerActivityTaskFromProto( +func (s *taskSerializer) timerActivityTaskFromProto( activityTimer *persistencespb.TimerTaskInfo, ) *tasks.ActivityTimeoutTask { return &tasks.ActivityTimeoutTask{ @@ -827,7 +834,7 @@ func (s *TaskSerializer) timerActivityTaskFromProto( } } -func (s *TaskSerializer) timerActivityRetryTaskToProto( +func (s *taskSerializer) timerActivityRetryTaskToProto( activityRetryTimer *tasks.ActivityRetryTimerTask, ) *persistencespb.TimerTaskInfo { return &persistencespb.TimerTaskInfo{ @@ -846,7 +853,7 @@ func (s *TaskSerializer) timerActivityRetryTaskToProto( } } -func (s *TaskSerializer) timerActivityRetryTaskFromProto( +func (s *taskSerializer) timerActivityRetryTaskFromProto( activityRetryTimer *persistencespb.TimerTaskInfo, ) *tasks.ActivityRetryTimerTask { return &tasks.ActivityRetryTimerTask{ @@ -864,7 +871,7 @@ func (s *TaskSerializer) 
timerActivityRetryTaskFromProto( } } -func (s *TaskSerializer) timerUserTaskToProto( +func (s *taskSerializer) timerUserTaskToProto( userTimer *tasks.UserTimerTask, ) *persistencespb.TimerTaskInfo { return &persistencespb.TimerTaskInfo{ @@ -881,7 +888,7 @@ func (s *TaskSerializer) timerUserTaskToProto( } } -func (s *TaskSerializer) timerUserTaskFromProto( +func (s *taskSerializer) timerUserTaskFromProto( userTimer *persistencespb.TimerTaskInfo, ) *tasks.UserTimerTask { return &tasks.UserTimerTask{ @@ -896,7 +903,7 @@ func (s *TaskSerializer) timerUserTaskFromProto( } } -func (s *TaskSerializer) timerWorkflowRunToProto( +func (s *taskSerializer) timerWorkflowRunToProto( workflowRunTimer *tasks.WorkflowRunTimeoutTask, ) *persistencespb.TimerTaskInfo { return &persistencespb.TimerTaskInfo{ @@ -914,7 +921,7 @@ func (s *TaskSerializer) timerWorkflowRunToProto( } } -func (s *TaskSerializer) timerWorkflowExecutionToProto( +func (s *taskSerializer) timerWorkflowExecutionToProto( workflowExecutionTimer *tasks.WorkflowExecutionTimeoutTask, ) *persistencespb.TimerTaskInfo { return &persistencespb.TimerTaskInfo{ @@ -933,7 +940,7 @@ func (s *TaskSerializer) timerWorkflowExecutionToProto( } } -func (s *TaskSerializer) timerWorkflowRunFromProto( +func (s *taskSerializer) timerWorkflowRunFromProto( workflowRunTimer *persistencespb.TimerTaskInfo, ) *tasks.WorkflowRunTimeoutTask { return &tasks.WorkflowRunTimeoutTask{ @@ -948,7 +955,7 @@ func (s *TaskSerializer) timerWorkflowRunFromProto( } } -func (s *TaskSerializer) timerWorkflowExecutionFromProto( +func (s *taskSerializer) timerWorkflowExecutionFromProto( workflowExecutionTimer *persistencespb.TimerTaskInfo, ) *tasks.WorkflowExecutionTimeoutTask { return &tasks.WorkflowExecutionTimeoutTask{ @@ -960,7 +967,7 @@ func (s *TaskSerializer) timerWorkflowExecutionFromProto( } } -func (s *TaskSerializer) timerWorkflowCleanupTaskToProto( +func (s *taskSerializer) timerWorkflowCleanupTaskToProto( workflowCleanupTimer 
*tasks.DeleteHistoryEventTask, ) *persistencespb.TimerTaskInfo { return &persistencespb.TimerTaskInfo{ @@ -982,7 +989,7 @@ func (s *TaskSerializer) timerWorkflowCleanupTaskToProto( } } -func (s *TaskSerializer) stateMachineTimerTaskToProto(task *tasks.StateMachineTimerTask) *persistencespb.TimerTaskInfo { +func (s *taskSerializer) stateMachineTimerTaskToProto(task *tasks.StateMachineTimerTask) *persistencespb.TimerTaskInfo { return &persistencespb.TimerTaskInfo{ NamespaceId: task.NamespaceID, WorkflowId: task.WorkflowID, @@ -994,7 +1001,7 @@ func (s *TaskSerializer) stateMachineTimerTaskToProto(task *tasks.StateMachineTi } } -func (s *TaskSerializer) timerWorkflowCleanupTaskFromProto( +func (s *taskSerializer) timerWorkflowCleanupTaskFromProto( workflowCleanupTimer *persistencespb.TimerTaskInfo, ) *tasks.DeleteHistoryEventTask { return &tasks.DeleteHistoryEventTask{ @@ -1012,7 +1019,7 @@ func (s *TaskSerializer) timerWorkflowCleanupTaskFromProto( } } -func (s *TaskSerializer) stateMachineTimerTaskFromProto(info *persistencespb.TimerTaskInfo) *tasks.StateMachineTimerTask { +func (s *taskSerializer) stateMachineTimerTaskFromProto(info *persistencespb.TimerTaskInfo) *tasks.StateMachineTimerTask { return &tasks.StateMachineTimerTask{ WorkflowKey: definition.NewWorkflowKey( info.NamespaceId, @@ -1025,7 +1032,7 @@ func (s *TaskSerializer) stateMachineTimerTaskFromProto(info *persistencespb.Tim } } -func (s *TaskSerializer) visibilityStartTaskToProto( +func (s *taskSerializer) visibilityStartTaskToProto( startVisibilityTask *tasks.StartExecutionVisibilityTask, ) *persistencespb.VisibilityTaskInfo { return &persistencespb.VisibilityTaskInfo{ @@ -1039,7 +1046,7 @@ func (s *TaskSerializer) visibilityStartTaskToProto( } } -func (s *TaskSerializer) visibilityStartTaskFromProto( +func (s *taskSerializer) visibilityStartTaskFromProto( startVisibilityTask *persistencespb.VisibilityTaskInfo, ) *tasks.StartExecutionVisibilityTask { return &tasks.StartExecutionVisibilityTask{ @@ 
-1054,7 +1061,7 @@ func (s *TaskSerializer) visibilityStartTaskFromProto( } } -func (s *TaskSerializer) visibilityUpsertTaskToProto( +func (s *taskSerializer) visibilityUpsertTaskToProto( upsertVisibilityTask *tasks.UpsertExecutionVisibilityTask, ) *persistencespb.VisibilityTaskInfo { return &persistencespb.VisibilityTaskInfo{ @@ -1067,7 +1074,7 @@ func (s *TaskSerializer) visibilityUpsertTaskToProto( } } -func (s *TaskSerializer) visibilityUpsertTaskFromProto( +func (s *taskSerializer) visibilityUpsertTaskFromProto( upsertVisibilityTask *persistencespb.VisibilityTaskInfo, ) *tasks.UpsertExecutionVisibilityTask { return &tasks.UpsertExecutionVisibilityTask{ @@ -1081,7 +1088,7 @@ func (s *TaskSerializer) visibilityUpsertTaskFromProto( } } -func (s *TaskSerializer) visibilityCloseTaskToProto( +func (s *taskSerializer) visibilityCloseTaskToProto( closetVisibilityTask *tasks.CloseExecutionVisibilityTask, ) *persistencespb.VisibilityTaskInfo { return &persistencespb.VisibilityTaskInfo{ @@ -1095,7 +1102,7 @@ func (s *TaskSerializer) visibilityCloseTaskToProto( } } -func (s *TaskSerializer) visibilityCloseTaskFromProto( +func (s *taskSerializer) visibilityCloseTaskFromProto( closeVisibilityTask *persistencespb.VisibilityTaskInfo, ) *tasks.CloseExecutionVisibilityTask { return &tasks.CloseExecutionVisibilityTask{ @@ -1110,7 +1117,7 @@ func (s *TaskSerializer) visibilityCloseTaskFromProto( } } -func (s *TaskSerializer) visibilityDeleteTaskToProto( +func (s *taskSerializer) visibilityDeleteTaskToProto( deleteVisibilityTask *tasks.DeleteExecutionVisibilityTask, ) *persistencespb.VisibilityTaskInfo { return &persistencespb.VisibilityTaskInfo{ @@ -1125,7 +1132,7 @@ func (s *TaskSerializer) visibilityDeleteTaskToProto( } } -func (s *TaskSerializer) visibilityDeleteTaskFromProto( +func (s *taskSerializer) visibilityDeleteTaskFromProto( deleteVisibilityTask *persistencespb.VisibilityTaskInfo, ) *tasks.DeleteExecutionVisibilityTask { return &tasks.DeleteExecutionVisibilityTask{ @@ 
-1141,7 +1148,7 @@ func (s *TaskSerializer) visibilityDeleteTaskFromProto( } } -func (s *TaskSerializer) visibilityChasmTaskToProto(task *tasks.ChasmTask) *persistencespb.VisibilityTaskInfo { +func (s *taskSerializer) visibilityChasmTaskToProto(task *tasks.ChasmTask) *persistencespb.VisibilityTaskInfo { return &persistencespb.VisibilityTaskInfo{ NamespaceId: task.WorkflowKey.NamespaceID, WorkflowId: task.WorkflowKey.WorkflowID, @@ -1155,7 +1162,7 @@ func (s *TaskSerializer) visibilityChasmTaskToProto(task *tasks.ChasmTask) *pers } } -func (s *TaskSerializer) visibilityChasmTaskFromProto(task *persistencespb.VisibilityTaskInfo) tasks.Task { +func (s *taskSerializer) visibilityChasmTaskFromProto(task *persistencespb.VisibilityTaskInfo) tasks.Task { return &tasks.ChasmTask{ WorkflowKey: definition.NewWorkflowKey( task.NamespaceId, @@ -1169,7 +1176,7 @@ func (s *TaskSerializer) visibilityChasmTaskFromProto(task *persistencespb.Visib } } -func (s *TaskSerializer) replicationActivityTaskToProto( +func (s *taskSerializer) replicationActivityTaskToProto( activityTask *tasks.SyncActivityTask, ) *persistencespb.ReplicationTaskInfo { return &persistencespb.ReplicationTaskInfo{ @@ -1189,7 +1196,7 @@ func (s *TaskSerializer) replicationActivityTaskToProto( } } -func (s *TaskSerializer) replicationActivityTaskFromProto( +func (s *taskSerializer) replicationActivityTaskFromProto( activityTask *persistencespb.ReplicationTaskInfo, ) *tasks.SyncActivityTask { visibilityTimestamp := time.Unix(0, 0) @@ -1210,7 +1217,7 @@ func (s *TaskSerializer) replicationActivityTaskFromProto( } } -func (s *TaskSerializer) replicationHistoryTaskToProto( +func (s *taskSerializer) replicationHistoryTaskToProto( historyTask *tasks.HistoryReplicationTask, ) *persistencespb.ReplicationTaskInfo { return &persistencespb.ReplicationTaskInfo{ @@ -1231,7 +1238,7 @@ func (s *TaskSerializer) replicationHistoryTaskToProto( } } -func (s *TaskSerializer) replicationHistoryTaskFromProto( +func (s *taskSerializer) 
replicationHistoryTaskFromProto( historyTask *persistencespb.ReplicationTaskInfo, ) *tasks.HistoryReplicationTask { visibilityTimestamp := time.Unix(0, 0) @@ -1256,7 +1263,7 @@ func (s *TaskSerializer) replicationHistoryTaskFromProto( } } -func (s *TaskSerializer) archiveExecutionTaskToProto( +func (s *taskSerializer) archiveExecutionTaskToProto( archiveExecutionTask *tasks.ArchiveExecutionTask, ) *persistencespb.ArchivalTaskInfo { return &persistencespb.ArchivalTaskInfo{ @@ -1270,7 +1277,7 @@ func (s *TaskSerializer) archiveExecutionTaskToProto( } } -func (s *TaskSerializer) archiveExecutionTaskFromProto( +func (s *taskSerializer) archiveExecutionTaskFromProto( archivalTaskInfo *persistencespb.ArchivalTaskInfo, ) *tasks.ArchiveExecutionTask { visibilityTimestamp := time.Unix(0, 0) @@ -1289,7 +1296,7 @@ func (s *TaskSerializer) archiveExecutionTaskFromProto( } } -func (s *TaskSerializer) replicationSyncWorkflowStateTaskToProto( +func (s *taskSerializer) replicationSyncWorkflowStateTaskToProto( syncWorkflowStateTask *tasks.SyncWorkflowStateTask, ) *persistencespb.ReplicationTaskInfo { return &persistencespb.ReplicationTaskInfo{ @@ -1305,7 +1312,7 @@ func (s *TaskSerializer) replicationSyncWorkflowStateTaskToProto( } } -func (s *TaskSerializer) replicationSyncWorkflowStateTaskFromProto( +func (s *taskSerializer) replicationSyncWorkflowStateTaskFromProto( syncWorkflowStateTask *persistencespb.ReplicationTaskInfo, ) *tasks.SyncWorkflowStateTask { visibilityTimestamp := time.Unix(0, 0) @@ -1326,7 +1333,7 @@ func (s *TaskSerializer) replicationSyncWorkflowStateTaskFromProto( } } -func (s *TaskSerializer) replicationSyncHSMTaskToProto( +func (s *taskSerializer) replicationSyncHSMTaskToProto( syncHSMTask *tasks.SyncHSMTask, ) *persistencespb.ReplicationTaskInfo { return &persistencespb.ReplicationTaskInfo{ @@ -1340,7 +1347,7 @@ func (s *TaskSerializer) replicationSyncHSMTaskToProto( } } -func (s *TaskSerializer) replicationSyncHSMTaskFromProto( +func (s *taskSerializer) 
replicationSyncHSMTaskFromProto( syncHSMTask *persistencespb.ReplicationTaskInfo, ) *tasks.SyncHSMTask { visibilityTimestamp := time.Unix(0, 0) @@ -1359,12 +1366,12 @@ func (s *TaskSerializer) replicationSyncHSMTaskFromProto( } } -func (s *TaskSerializer) replicationSyncVersionedTransitionTaskToProto( +func (s *taskSerializer) replicationSyncVersionedTransitionTaskToProto( syncVersionedTransitionTask *tasks.SyncVersionedTransitionTask, ) (*persistencespb.ReplicationTaskInfo, error) { taskInfoEquivalents := make([]*persistencespb.ReplicationTaskInfo, 0, len(syncVersionedTransitionTask.TaskEquivalents)) for _, task := range syncVersionedTransitionTask.TaskEquivalents { - taskInfoEquivalent, err := s.ParseReplicationTaskInfo(task) + taskInfoEquivalent, err := s.SeralizeReplicationTask(task) if err != nil { return nil, err } @@ -1390,13 +1397,13 @@ func (s *TaskSerializer) replicationSyncVersionedTransitionTaskToProto( }, nil } -func (s *TaskSerializer) replicationSyncVersionedTransitionTaskFromProto( +func (s *taskSerializer) replicationSyncVersionedTransitionTaskFromProto( syncVersionedTransitionTask *persistencespb.ReplicationTaskInfo, ) (*tasks.SyncVersionedTransitionTask, error) { taskEquivalents := make([]tasks.Task, 0, len(syncVersionedTransitionTask.TaskEquivalents)) for _, taskInfoEquivalent := range syncVersionedTransitionTask.TaskEquivalents { - taskEquivalent, err := s.ParseReplicationTask(taskInfoEquivalent) + taskEquivalent, err := s.DeserializeReplicationTask(taskInfoEquivalent) if err != nil { return nil, err } @@ -1427,10 +1434,13 @@ func (s *TaskSerializer) replicationSyncVersionedTransitionTaskFromProto( }, nil } -func (s *TaskSerializer) serializeOutboundTask(task tasks.Task) (*commonpb.DataBlob, error) { +func (s *taskSerializer) serializeOutboundTask( + task tasks.Task, +) (*commonpb.DataBlob, error) { + var outboundTaskInfo *persistencespb.OutboundTaskInfo switch task := task.(type) { case *tasks.StateMachineOutboundTask: - return 
ProtoEncode(&persistencespb.OutboundTaskInfo{ + outboundTaskInfo = &persistencespb.OutboundTaskInfo{ NamespaceId: task.NamespaceID, WorkflowId: task.WorkflowID, RunId: task.RunID, @@ -1441,9 +1451,9 @@ func (s *TaskSerializer) serializeOutboundTask(task tasks.Task) (*commonpb.DataB TaskDetails: &persistencespb.OutboundTaskInfo_StateMachineInfo{ StateMachineInfo: task.Info, }, - }) + } case *tasks.ChasmTask: - return ProtoEncode(&persistencespb.OutboundTaskInfo{ + outboundTaskInfo = &persistencespb.OutboundTaskInfo{ NamespaceId: task.NamespaceID, WorkflowId: task.WorkflowID, RunId: task.RunID, @@ -1453,18 +1463,21 @@ func (s *TaskSerializer) serializeOutboundTask(task tasks.Task) (*commonpb.DataB VisibilityTime: timestamppb.New(task.VisibilityTimestamp), TaskDetails: &persistencespb.OutboundTaskInfo_ChasmTaskInfo{ ChasmTaskInfo: task.Info, - }}) + }, + } default: return nil, serviceerror.NewInternalf("unknown outbound task type while serializing: %v", task) } + return s.serializer.OutboundTaskInfoToBlob(outboundTaskInfo) } -func (s *TaskSerializer) deserializeOutboundTask(blob *commonpb.DataBlob) (tasks.Task, error) { - info := &persistencespb.OutboundTaskInfo{} - if err := Decode(blob, info); err != nil { +func (s *taskSerializer) deserializeOutboundTask( + blob *commonpb.DataBlob, +) (tasks.Task, error) { + info, err := s.serializer.OutboundTaskInfoFromBlob(blob) + if err != nil { return nil, err } - switch info.TaskType { case enumsspb.TASK_TYPE_STATE_MACHINE_OUTBOUND: return &tasks.StateMachineOutboundTask{ diff --git a/common/persistence/serialization/task_serializer_test.go b/common/persistence/serialization/task_serializer_test.go index 6b5bd2c96b1..449be1b8c9f 100644 --- a/common/persistence/serialization/task_serializer_test.go +++ b/common/persistence/serialization/task_serializer_test.go @@ -28,7 +28,7 @@ type ( *require.Assertions workflowKey definition.WorkflowKey - taskSerializer *TaskSerializer + taskSerializer TaskSerializer } ) @@ -52,7 +52,7 @@ 
func (s *taskSerializerSuite) SetupTest() { "random workflow ID", "random run ID", ) - s.taskSerializer = NewTaskSerializer() + s.taskSerializer = NewTaskSerializer(NewSerializer()) } func (s *taskSerializerSuite) TearDownTest() { diff --git a/common/persistence/sql/common.go b/common/persistence/sql/common.go index 505818c9a9b..a11ccfc0ca3 100644 --- a/common/persistence/sql/common.go +++ b/common/persistence/sql/common.go @@ -13,19 +13,22 @@ import ( "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/sql/sqlplugin" ) // TODO: Rename all SQL Managers to Stores type SqlStore struct { - DB sqlplugin.DB - logger log.Logger + DB sqlplugin.DB + logger log.Logger + serializer serialization.Serializer } func NewSqlStore(db sqlplugin.DB, logger log.Logger) SqlStore { return SqlStore{ - DB: db, - logger: logger, + DB: db, + logger: logger, + serializer: serialization.NewSerializer(), } } diff --git a/common/persistence/sql/execution.go b/common/persistence/sql/execution.go index d560a5ce948..9b44ef5d8ce 100644 --- a/common/persistence/sql/execution.go +++ b/common/persistence/sql/execution.go @@ -12,13 +12,14 @@ import ( persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common/log" p "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/sql/sqlplugin" "go.temporal.io/server/common/primitives" ) type sqlExecutionStore struct { SqlStore - p.HistoryBranchUtilImpl + p.HistoryBranchUtil } var _ p.ExecutionStore = (*sqlExecutionStore)(nil) @@ -27,10 +28,11 @@ var _ p.ExecutionStore = (*sqlExecutionStore)(nil) func NewSQLExecutionStore( db sqlplugin.DB, logger log.Logger, + serializer serialization.Serializer, ) (p.ExecutionStore, error) { - return &sqlExecutionStore{ - SqlStore: NewSqlStore(db, 
logger), + SqlStore: NewSqlStore(db, logger), + HistoryBranchUtil: p.NewHistoryBranchUtil(serializer), }, nil } @@ -423,7 +425,7 @@ func (m *sqlExecutionStore) updateWorkflowExecutionTx( return serviceerror.NewUnavailablef("UpdateWorkflowExecution: unknown mode: %v", request.Mode) } - if err := applyWorkflowMutationTx(ctx, tx, shardID, &updateWorkflow); err != nil { + if err := m.applyWorkflowMutationTx(ctx, tx, shardID, &updateWorkflow); err != nil { return err } @@ -534,7 +536,7 @@ func (m *sqlExecutionStore) conflictResolveWorkflowExecutionTx( return serviceerror.NewUnavailablef("ConflictResolveWorkflowExecution: unknown mode: %v", request.Mode) } - if err := applyWorkflowSnapshotTxAsReset(ctx, + if err := m.applyWorkflowSnapshotTxAsReset(ctx, tx, shardID, &resetWorkflow, @@ -543,7 +545,7 @@ func (m *sqlExecutionStore) conflictResolveWorkflowExecutionTx( } if currentWorkflow != nil { - if err := applyWorkflowMutationTx(ctx, + if err := m.applyWorkflowMutationTx(ctx, tx, shardID, currentWorkflow, @@ -713,7 +715,7 @@ func (m *sqlExecutionStore) setWorkflowExecutionTx( shardID := request.ShardID setSnapshot := request.SetWorkflowSnapshot - return applyWorkflowSnapshotTxAsReset(ctx, + return m.applyWorkflowSnapshotTxAsReset(ctx, tx, shardID, &setSnapshot, @@ -727,6 +729,10 @@ func (m *sqlExecutionStore) ListConcreteExecutions( return nil, serviceerror.NewUnimplemented("ListConcreteExecutions is not implemented") } +func (m *sqlExecutionStore) GetHistoryBranchUtil() p.HistoryBranchUtil { + return m.HistoryBranchUtil +} + func getStartTimeFromState(state *persistencespb.WorkflowExecutionState) *time.Time { if state == nil || state.StartTime == nil { return nil diff --git a/common/persistence/sql/execution_tasks.go b/common/persistence/sql/execution_tasks.go index b9b6aff41e5..dec53bc3869 100644 --- a/common/persistence/sql/execution_tasks.go +++ b/common/persistence/sql/execution_tasks.go @@ -10,7 +10,6 @@ import ( "go.temporal.io/api/serviceerror" p 
"go.temporal.io/server/common/persistence" - "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/sql/sqlplugin" "go.temporal.io/server/service/history/tasks" ) @@ -586,7 +585,7 @@ func (m *sqlExecutionStore) PutReplicationTaskToDLQ( request *p.PutReplicationTaskToDLQRequest, ) error { replicationTask := request.TaskInfo - blob, err := serialization.ReplicationTaskInfoToBlob(replicationTask) + blob, err := m.serializer.ReplicationTaskInfoToBlob(replicationTask) if err != nil { return err diff --git a/common/persistence/sql/execution_util.go b/common/persistence/sql/execution_util.go index 1b09c14f582..cd02c651dc9 100644 --- a/common/persistence/sql/execution_util.go +++ b/common/persistence/sql/execution_util.go @@ -19,7 +19,7 @@ import ( "go.temporal.io/server/service/history/tasks" ) -func applyWorkflowMutationTx( +func (m *sqlExecutionStore) applyWorkflowMutationTx( ctx context.Context, tx sqlplugin.Tx, shardID int32, @@ -58,7 +58,7 @@ func applyWorkflowMutationTx( } } - if err := updateExecution(ctx, + if err := m.updateExecution(ctx, tx, namespaceID, workflowID, @@ -189,13 +189,12 @@ func applyWorkflowMutationTx( return nil } -func applyWorkflowSnapshotTxAsReset( +func (m *sqlExecutionStore) applyWorkflowSnapshotTxAsReset( ctx context.Context, tx sqlplugin.Tx, shardID int32, workflowSnapshot *p.InternalWorkflowSnapshot, ) error { - lastWriteVersion := workflowSnapshot.LastWriteVersion workflowID := workflowSnapshot.WorkflowID namespaceID := workflowSnapshot.NamespaceID @@ -227,7 +226,7 @@ func applyWorkflowSnapshotTxAsReset( } } - if err := updateExecution(ctx, + if err := m.updateExecution(ctx, tx, namespaceID, workflowID, @@ -1057,7 +1056,7 @@ func updateCurrentExecution( return nil } -func buildExecutionRow( +func (m *sqlExecutionStore) buildExecutionRow( namespaceID string, workflowID string, executionInfo *commonpb.DataBlob, @@ -1069,7 +1068,7 @@ func buildExecutionRow( ) (row *sqlplugin.ExecutionsRow, err 
error) { // TODO: double encoding execution state? executionState could've been passed to the function as // *commonpb.DataBlob like executionInfo - stateBlob, err := serialization.WorkflowExecutionStateToBlob(executionState) + stateBlob, err := m.serializer.WorkflowExecutionStateToBlob(executionState) if err != nil { return nil, err } @@ -1112,7 +1111,7 @@ func (m *sqlExecutionStore) createExecution( shardID int32, ) error { - row, err := buildExecutionRow( + row, err := m.buildExecutionRow( namespaceID, workflowID, executionInfo, @@ -1147,7 +1146,7 @@ func (m *sqlExecutionStore) createExecution( return nil } -func updateExecution( +func (m *sqlExecutionStore) updateExecution( ctx context.Context, tx sqlplugin.Tx, namespaceID string, @@ -1159,7 +1158,7 @@ func updateExecution( dbRecordVersion int64, shardID int32, ) error { - row, err := buildExecutionRow( + row, err := m.buildExecutionRow( namespaceID, workflowID, executionInfo, @@ -1191,7 +1190,7 @@ func workflowExecutionStateFromCurrentExecutionsRow( row *sqlplugin.CurrentExecutionsRow, ) (*persistencespb.WorkflowExecutionState, error) { if len(row.Data) > 0 && row.DataEncoding != "" { - return serialization.WorkflowExecutionStateFromBlob(p.NewDataBlob(row.Data, row.DataEncoding)) + return serialization.DefaultDecoder.WorkflowExecutionStateFromBlob(p.NewDataBlob(row.Data, row.DataEncoding)) } // Old records don't have the serialized WorkflowExecutionState stored in DB. 
diff --git a/common/persistence/sql/factory.go b/common/persistence/sql/factory.go index 6caf9a05584..6b1558ae2d4 100644 --- a/common/persistence/sql/factory.go +++ b/common/persistence/sql/factory.go @@ -8,6 +8,7 @@ import ( "go.temporal.io/server/common/log" "go.temporal.io/server/common/metrics" p "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/sql/sqlplugin" "go.temporal.io/server/common/resolver" ) @@ -19,6 +20,7 @@ type ( mainDBConn DbConn clusterName string logger log.Logger + serializer serialization.Serializer } // DbConn represents a logical mysql connection - its a @@ -44,6 +46,7 @@ func NewFactory( cfg config.SQL, r resolver.ServiceResolver, clusterName string, + serializer serialization.Serializer, logger log.Logger, metricsHandler metrics.Handler, ) *Factory { @@ -51,6 +54,7 @@ func NewFactory( cfg: cfg, clusterName: clusterName, logger: logger, + serializer: serializer, mainDBConn: NewRefCountedDBConn(sqlplugin.DbKindMain, &cfg, r, logger, metricsHandler), } } @@ -115,7 +119,7 @@ func (f *Factory) NewExecutionStore() (p.ExecutionStore, error) { if err != nil { return nil, err } - return NewSQLExecutionStore(conn, f.logger) + return NewSQLExecutionStore(conn, f.logger, f.serializer) } // NewQueue returns a new queue backed by sql diff --git a/common/persistence/sql/history_store.go b/common/persistence/sql/history_store.go index f961af1bc11..9703b6b32bd 100644 --- a/common/persistence/sql/history_store.go +++ b/common/persistence/sql/history_store.go @@ -157,7 +157,7 @@ func (m *sqlExecutionStore) ReadHistoryBranch( ctx context.Context, request *p.InternalReadHistoryBranchRequest, ) (*p.InternalReadHistoryBranchResponse, error) { - branch, err := m.GetHistoryBranchUtil().ParseHistoryBranchInfo(request.BranchToken) + branch, err := m.ParseHistoryBranchInfo(request.BranchToken) if err != nil { return nil, err } @@ -449,7 +449,7 @@ func (m *sqlExecutionStore) 
GetHistoryTreeContainingBranch( ctx context.Context, request *p.InternalGetHistoryTreeContainingBranchRequest, ) (*p.InternalGetHistoryTreeContainingBranchResponse, error) { - branch, err := m.GetHistoryBranchUtil().ParseHistoryBranchInfo(request.BranchToken) + branch, err := m.ParseHistoryBranchInfo(request.BranchToken) if err != nil { return nil, err } diff --git a/common/persistence/sql/sqlplugin/tests/history_current_execution.go b/common/persistence/sql/sqlplugin/tests/history_current_execution.go index e238eb12d00..93f3391abb1 100644 --- a/common/persistence/sql/sqlplugin/tests/history_current_execution.go +++ b/common/persistence/sql/sqlplugin/tests/history_current_execution.go @@ -323,7 +323,7 @@ func (s *historyCurrentExecutionSuite) newRandomCurrentExecutionRow( }, }, } - executionStateBlob, _ := serialization.WorkflowExecutionStateToBlob(executionState) + executionStateBlob, _ := serialization.NewSerializer().WorkflowExecutionStateToBlob(executionState) return sqlplugin.CurrentExecutionsRow{ ShardID: shardID, NamespaceID: namespaceID, diff --git a/common/persistence/tests/cassandra_test.go b/common/persistence/tests/cassandra_test.go index 7426cc4ef9e..5619cbf10f7 100644 --- a/common/persistence/tests/cassandra_test.go +++ b/common/persistence/tests/cassandra_test.go @@ -155,7 +155,6 @@ func TestCassandraExecutionMutableStateStoreSuite(t *testing.T) { shardStore, executionStore, serialization.NewSerializer(), - &persistence.HistoryBranchUtilImpl{}, testData.Logger, ) suite.Run(t, s) diff --git a/common/persistence/tests/cassandra_test_util.go b/common/persistence/tests/cassandra_test_util.go index 6a06aaf2a88..6a63789f2f7 100644 --- a/common/persistence/tests/cassandra_test_util.go +++ b/common/persistence/tests/cassandra_test_util.go @@ -18,6 +18,7 @@ import ( p "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/cassandra" commongocql "go.temporal.io/server/common/persistence/nosql/nosqlplugin/cassandra/gocql" + 
"go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/resolver" "go.temporal.io/server/common/shuffle" "go.temporal.io/server/temporal/environment" @@ -61,6 +62,7 @@ func setUpCassandraTest(t *testing.T) (CassandraTestData, func()) { testCassandraClusterName, testData.Logger, metrics.NoopMetricsHandler, + serialization.NewSerializer(), ) tearDown := func() { diff --git a/common/persistence/tests/execution_mutable_state.go b/common/persistence/tests/execution_mutable_state.go index 62c19918aa0..67e48ce172a 100644 --- a/common/persistence/tests/execution_mutable_state.go +++ b/common/persistence/tests/execution_mutable_state.go @@ -56,7 +56,6 @@ func NewExecutionMutableStateSuite( shardStore p.ShardStore, executionStore p.ExecutionStore, serializer serialization.Serializer, - historyBranchUtil p.HistoryBranchUtil, logger log.Logger, ) *ExecutionMutableStateSuite { return &ExecutionMutableStateSuite{ @@ -70,10 +69,11 @@ func NewExecutionMutableStateSuite( executionStore, serializer, nil, + nil, logger, dynamicconfig.GetIntPropertyFn(4*1024*1024), ), - historyBranchUtil: historyBranchUtil, + historyBranchUtil: p.NewHistoryBranchUtil(serializer), Logger: logger, } } diff --git a/common/persistence/tests/execution_mutable_state_task.go b/common/persistence/tests/execution_mutable_state_task.go index 446ec82562f..be844953221 100644 --- a/common/persistence/tests/execution_mutable_state_task.go +++ b/common/persistence/tests/execution_mutable_state_task.go @@ -43,7 +43,7 @@ type ( Cancel context.CancelFunc } - testSerializer struct { + testTaskSerializer struct { serialization.Serializer } ) @@ -51,15 +51,6 @@ type ( var ( fakeImmediateTaskCategory = tasks.NewCategory(1234, tasks.CategoryTypeImmediate, "fake-immediate") fakeScheduledTaskCategory = tasks.NewCategory(2345, tasks.CategoryTypeScheduled, "fake-scheduled") - - taskCategories = []tasks.Category{ - tasks.CategoryTransfer, - tasks.CategoryTimer, - tasks.CategoryReplication, - 
tasks.CategoryVisibility, - fakeImmediateTaskCategory, - fakeScheduledTaskCategory, - } ) func NewExecutionMutableStateTaskSuite( @@ -69,7 +60,7 @@ func NewExecutionMutableStateTaskSuite( serializer serialization.Serializer, logger log.Logger, ) *ExecutionMutableStateTaskSuite { - serializer = newTestSerializer(serializer) + taskSerializer := newTestTaskSerializer(serializer) return &ExecutionMutableStateTaskSuite{ Assertions: require.New(t), ShardManager: p.NewShardManager( @@ -79,6 +70,7 @@ func NewExecutionMutableStateTaskSuite( ExecutionManager: p.NewExecutionManager( executionStore, serializer, + taskSerializer, nil, logger, dynamicconfig.GetIntPropertyFn(4*1024*1024), @@ -657,15 +649,15 @@ func (s *ExecutionMutableStateTaskSuite) GetAndCompleteHistoryTask( s.Empty(historyTasks) } -func newTestSerializer( +func newTestTaskSerializer( serializer serialization.Serializer, -) serialization.Serializer { - return &testSerializer{ +) serialization.TaskSerializer { + return &testTaskSerializer{ Serializer: serializer, } } -func (s *testSerializer) SerializeTask( +func (s *testTaskSerializer) SerializeTask( task tasks.Task, ) (*commonpb.DataBlob, error) { if fakeTask, ok := task.(*tasks.FakeTask); ok { @@ -686,18 +678,18 @@ func (s *testSerializer) SerializeTask( EncodingType: enumspb.ENCODING_TYPE_PROTO3, }, nil } - - return s.Serializer.SerializeTask(task) + return serialization.NewTaskSerializer(s.Serializer).SerializeTask(task) } -func (s *testSerializer) DeserializeTask( +func (s *testTaskSerializer) DeserializeTask( category tasks.Category, blob *commonpb.DataBlob, ) (tasks.Task, error) { categoryID := category.ID() if categoryID != fakeImmediateTaskCategory.ID() && categoryID != fakeScheduledTaskCategory.ID() { - return s.Serializer.DeserializeTask(category, blob) + taskSerializer := serialization.NewTaskSerializer(s.Serializer) + return taskSerializer.DeserializeTask(category, blob) } taskInfo := &persistencespb.TransferTaskInfo{} @@ -718,3 +710,11 @@ func (s 
*testSerializer) DeserializeTask( return fakeTask, nil } + +func (s *testTaskSerializer) SeralizeReplicationTask(task tasks.Task) (*persistencespb.ReplicationTaskInfo, error) { + panic("not implemented") +} + +func (s *testTaskSerializer) DeserializeReplicationTask(replicationTask *persistencespb.ReplicationTaskInfo) (tasks.Task, error) { + panic("not implemented") +} diff --git a/common/persistence/tests/history_store.go b/common/persistence/tests/history_store.go index c4aeead0d94..cff33545b4f 100644 --- a/common/persistence/tests/history_store.go +++ b/common/persistence/tests/history_store.go @@ -55,18 +55,19 @@ func NewHistoryEventsSuite( store p.ExecutionStore, logger log.Logger, ) *HistoryEventsSuite { - eventSerializer := serialization.NewSerializer() + serializer := serialization.NewSerializer() return &HistoryEventsSuite{ Assertions: require.New(t), ProtoAssertions: protorequire.New(t), store: p.NewExecutionManager( store, - eventSerializer, + serializer, + nil, nil, logger, dynamicconfig.GetIntPropertyFn(4*1024*1024), ), - serializer: eventSerializer, + serializer: serializer, logger: logger, } } diff --git a/common/persistence/tests/mysql_test.go b/common/persistence/tests/mysql_test.go index ad953d99154..b6387f4cfc8 100644 --- a/common/persistence/tests/mysql_test.go +++ b/common/persistence/tests/mysql_test.go @@ -7,7 +7,6 @@ import ( "github.com/stretchr/testify/suite" "go.temporal.io/server/common/log" "go.temporal.io/server/common/metrics" - "go.temporal.io/server/common/persistence" persistencetests "go.temporal.io/server/common/persistence/persistence-tests" "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/sql" @@ -53,7 +52,6 @@ func TestMySQLExecutionMutableStateStoreSuite(t *testing.T) { shardStore, executionStore, serialization.NewSerializer(), - &persistence.HistoryBranchUtilImpl{}, testData.Logger, ) suite.Run(t, s) diff --git a/common/persistence/tests/mysql_test_util.go 
b/common/persistence/tests/mysql_test_util.go index 1b432d13c7a..2405e2625bf 100644 --- a/common/persistence/tests/mysql_test_util.go +++ b/common/persistence/tests/mysql_test_util.go @@ -11,6 +11,7 @@ import ( "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/metrics/metricstest" p "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/sql" "go.temporal.io/server/common/persistence/sql/sqlplugin" "go.temporal.io/server/common/persistence/sql/sqlplugin/mysql" @@ -58,6 +59,7 @@ func setUpMySQLTest(t *testing.T) (MySQLTestData, func()) { *testData.Cfg, resolver.NewNoopResolver(), testMySQLClusterName, + serialization.NewSerializer(), testData.Logger, mh, ) diff --git a/common/persistence/tests/postgresql_test.go b/common/persistence/tests/postgresql_test.go index 197c92760b1..6d2d61e6058 100644 --- a/common/persistence/tests/postgresql_test.go +++ b/common/persistence/tests/postgresql_test.go @@ -6,7 +6,6 @@ import ( "github.com/stretchr/testify/suite" "go.temporal.io/server/common/log" "go.temporal.io/server/common/metrics" - "go.temporal.io/server/common/persistence" persistencetests "go.temporal.io/server/common/persistence/persistence-tests" "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/sql" @@ -57,7 +56,6 @@ func (p *PostgreSQLSuite) TestPostgreSQLExecutionMutableStateStoreSuite() { shardStore, executionStore, serialization.NewSerializer(), - &persistence.HistoryBranchUtilImpl{}, testData.Logger, ) suite.Run(p.T(), s) diff --git a/common/persistence/tests/postgresql_test_util.go b/common/persistence/tests/postgresql_test_util.go index c3b0b355049..7df9243fa22 100644 --- a/common/persistence/tests/postgresql_test_util.go +++ b/common/persistence/tests/postgresql_test_util.go @@ -11,6 +11,7 @@ import ( "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/metrics/metricstest" p 
"go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/sql" "go.temporal.io/server/common/persistence/sql/sqlplugin" "go.temporal.io/server/common/resolver" @@ -57,6 +58,7 @@ func setUpPostgreSQLTest(t *testing.T, pluginName string) (PostgreSQLTestData, f *testData.Cfg, resolver.NewNoopResolver(), testPostgreSQLClusterName, + serialization.NewSerializer(), testData.Logger, mh, ) diff --git a/common/persistence/tests/sqlite_test.go b/common/persistence/tests/sqlite_test.go index de9f3bc40ec..9593e2c48c5 100644 --- a/common/persistence/tests/sqlite_test.go +++ b/common/persistence/tests/sqlite_test.go @@ -92,6 +92,7 @@ func TestSQLiteExecutionMutableStateStoreSuite(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -112,7 +113,6 @@ func TestSQLiteExecutionMutableStateStoreSuite(t *testing.T) { shardStore, executionStore, serialization.NewSerializer(), - &persistence.HistoryBranchUtilImpl{}, logger, ) suite.Run(t, s) @@ -125,6 +125,7 @@ func TestSQLiteExecutionMutableStateTaskStoreSuite(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -157,6 +158,7 @@ func TestSQLiteHistoryStoreSuite(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -179,6 +181,7 @@ func TestSQLiteTaskQueueSuite(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -201,6 +204,7 @@ func TestSQLiteFairTaskQueueSuite(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -223,6 +227,7 @@ func TestSQLiteTaskQueueTaskSuite(t *testing.T) { *cfg, 
resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -245,6 +250,7 @@ func TestSQLiteTaskQueueFairTaskSuite(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -267,6 +273,7 @@ func TestSQLiteTaskQueueUserDataSuite(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -293,6 +300,7 @@ func TestSQLiteFileExecutionMutableStateStoreSuite(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -313,7 +321,6 @@ func TestSQLiteFileExecutionMutableStateStoreSuite(t *testing.T) { shardStore, executionStore, serialization.NewSerializer(), - &persistence.HistoryBranchUtilImpl{}, logger, ) suite.Run(t, s) @@ -330,6 +337,7 @@ func TestSQLiteFileExecutionMutableStateTaskStoreSuite(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -366,6 +374,7 @@ func TestSQLiteFileHistoryStoreSuite(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -392,6 +401,7 @@ func TestSQLiteFileTaskQueueSuite(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -418,6 +428,7 @@ func TestSQLiteFileFairTaskQueueSuite(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -444,6 +455,7 @@ func TestSQLiteFileTaskQueueTaskSuite(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -470,6 +482,7 @@ func 
TestSQLiteFileTaskQueueFairTaskSuite(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -496,6 +509,7 @@ func TestSQLiteFileTaskQueueUserDataSuite(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -1274,6 +1288,7 @@ func TestSQLiteQueueV2(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) @@ -1292,6 +1307,7 @@ func TestSQLiteNexusEndpointPersistence(t *testing.T) { *cfg, resolver.NewNoopResolver(), testSQLiteClusterName, + serialization.NewSerializer(), logger, metrics.NoopMetricsHandler, ) diff --git a/common/persistence/xdc_cache.go b/common/persistence/xdc_cache.go index 39f5d66a562..e3e01387c7e 100644 --- a/common/persistence/xdc_cache.go +++ b/common/persistence/xdc_cache.go @@ -99,8 +99,7 @@ func NewEventsBlobCache( Pin: false, }, ), - logger: logger, - serializer: serialization.NewSerializer(), + logger: logger, } } diff --git a/common/resourcetest/test_resource.go b/common/resourcetest/test_resource.go index d9d403d1799..0ac0a055922 100644 --- a/common/resourcetest/test_resource.go +++ b/common/resourcetest/test_resource.go @@ -115,7 +115,7 @@ func NewTest(controller *gomock.Controller, serviceName primitives.ServiceName) taskMgr := persistence.NewMockTaskManager(controller) shardMgr := persistence.NewMockShardManager(controller) executionMgr := persistence.NewMockExecutionManager(controller) - executionMgr.EXPECT().GetHistoryBranchUtil().Return(&persistence.HistoryBranchUtilImpl{}).AnyTimes() + executionMgr.EXPECT().GetHistoryBranchUtil().Return(persistence.NewHistoryBranchUtil(serialization.NewSerializer())).AnyTimes() namespaceReplicationQueue := persistence.NewMockNamespaceReplicationQueue(controller) nexusEndpointMgr := persistence.NewMockNexusEndpointManager(controller) 
diff --git a/docs/development/testing.md b/docs/development/testing.md index 9c6dc01e0ec..dd38a3efd6b 100644 --- a/docs/development/testing.md +++ b/docs/development/testing.md @@ -13,6 +13,7 @@ This document describes the project's testing setup, utilities and best practice - `TEMPORAL_TEST_LOG_FORMAT`: Controls the output format for test logs. Available options: `json` or `console` - `TEMPORAL_TEST_LOG_LEVEL`: Sets the verbosity level for test logging. Available levels: `debug`, `info`, `warn`, `error`, `fatal` - `TEMPORAL_TEST_OTEL_OUTPUT`: Enables OpenTelemetry (OTEL) trace output for failed tests to the provided file path. +- `TEMPORAL_TEST_DATA_ENCODING`: If set, overrides the default data blob encoding. Available options: `json`, `proto3`. ## Test helpers diff --git a/service/history/api/addtasks/api_test.go b/service/history/api/addtasks/api_test.go index 1966c215bed..28a41021615 100644 --- a/service/history/api/addtasks/api_test.go +++ b/service/history/api/addtasks/api_test.go @@ -99,7 +99,7 @@ func TestInvoke(t *testing.T) { WorkflowKey: workflowKey, }, } { - serializer := serialization.NewTaskSerializer() + serializer := serialization.NewTaskSerializer(serialization.NewSerializer()) blob, err := serializer.SerializeTask(task) require.NoError(t, err) params.req.Tasks = append(params.req.Tasks, &historyservice.AddTasksRequest_Task{ @@ -237,7 +237,7 @@ func getDefaultTestParams(t *testing.T) *testParams { task := &tasks.WorkflowTask{ WorkflowKey: definition.NewWorkflowKey(string(tests.NamespaceID), tests.WorkflowID, tests.RunID), } - serializer := serialization.NewTaskSerializer() + serializer := serialization.NewTaskSerializer(serialization.NewSerializer()) blob, err := serializer.SerializeTask(task) require.NoError(t, err) ctrl := gomock.NewController(t) diff --git a/service/history/api/getdlqtasks/getdlqtaskstest/apitest.go b/service/history/api/getdlqtasks/getdlqtaskstest/apitest.go index a368b815318..d8da3a779b9 100644 --- 
a/service/history/api/getdlqtasks/getdlqtaskstest/apitest.go +++ b/service/history/api/getdlqtasks/getdlqtaskstest/apitest.go @@ -58,7 +58,7 @@ func TestInvoke(t *testing.T, manager persistence.HistoryTaskQueueManager) { require.Equal(t, 1, len(res.DlqTasks)) assert.Equal(t, int64(persistence.FirstQueueMessageID), res.DlqTasks[0].Metadata.MessageId) assert.Equal(t, 1, int(res.DlqTasks[0].Payload.ShardId)) - serializer := serialization.NewTaskSerializer() + serializer := serialization.NewTaskSerializer(serialization.NewSerializer()) outTask, err := serializer.DeserializeTask(tasks.CategoryTransfer, res.DlqTasks[0].Payload.Blob) require.NoError(t, err) assert.Equal(t, inTask, outTask) diff --git a/service/history/handler.go b/service/history/handler.go index 7e1c83604f3..dd91c414e46 100644 --- a/service/history/handler.go +++ b/service/history/handler.go @@ -1,6 +1,7 @@ package history import ( + "cmp" "context" "errors" "math" @@ -1808,9 +1809,10 @@ func (h *Handler) ReapplyEvents(ctx context.Context, request *historyservice.Rea } // deserialize history event object + eventsBlob := request.GetRequest().GetEvents() historyEvents, err := h.payloadSerializer.DeserializeEvents(&commonpb.DataBlob{ - EncodingType: enumspb.ENCODING_TYPE_PROTO3, - Data: request.GetRequest().GetEvents().GetData(), + EncodingType: cmp.Or(eventsBlob.GetEncodingType(), enumspb.ENCODING_TYPE_PROTO3), + Data: eventsBlob.GetData(), }) if err != nil { return nil, h.convertError(err) @@ -2153,7 +2155,7 @@ func (h *Handler) StreamWorkflowReplicationMessages( engine, shardContext, clientClusterName, - serialization.NewSerializer(), + h.payloadSerializer, ), clientClusterName, clientShardCount, diff --git a/service/history/history_engine.go b/service/history/history_engine.go index 5f352a7d314..581cf27f56c 100644 --- a/service/history/history_engine.go +++ b/service/history/history_engine.go @@ -1016,10 +1016,11 @@ func (e *historyEngineImpl) AddTasks( ctx context.Context, request 
*historyservice.AddTasksRequest, ) (_ *historyservice.AddTasksResponse, retError error) { + taskSerializer := serialization.NewTaskSerializer(e.eventSerializer) return addtasks.Invoke( ctx, e.shardContext, - e.eventSerializer, + taskSerializer, int(e.config.NumberOfShards), request, e.taskCategoryRegistry, diff --git a/service/history/ndc/workflow_state_replicator_test.go b/service/history/ndc/workflow_state_replicator_test.go index 4e191771e63..4021a1f32c4 100644 --- a/service/history/ndc/workflow_state_replicator_test.go +++ b/service/history/ndc/workflow_state_replicator_test.go @@ -54,6 +54,7 @@ type ( mockRemoteAdminClient *adminservicemock.MockAdminServiceClient mockExecutionManager *persistence.MockExecutionManager logger log.Logger + serializer serialization.Serializer workflowID string runID string @@ -100,11 +101,12 @@ func (s *workflowReplicatorSuite) SetupTest() { s.workflowID = "some random workflow ID" s.runID = uuid.New() s.now = time.Now().UTC() + s.serializer = serialization.NewSerializer() s.workflowStateReplicator = NewWorkflowStateReplicator( s.mockShard, s.mockWorkflowCache, eventReapplier, - serialization.NewSerializer(), + s.serializer, s.logger, ) } @@ -122,7 +124,7 @@ func (s *workflowReplicatorSuite) Test_ApplyWorkflowState_BrandNew() { BranchId: uuid.New(), Ancestors: nil, } - historyBranch, err := serialization.HistoryBranchToBlob(branchInfo) + historyBranch, err := s.serializer.HistoryBranchToBlob(branchInfo) s.NoError(err) completionEventBatchId := int64(5) nextEventID := int64(7) @@ -229,7 +231,7 @@ func (s *workflowReplicatorSuite) Test_ApplyWorkflowState_Ancestors() { }, }, } - historyBranch, err := serialization.HistoryBranchToBlob(branchInfo) + historyBranch, err := s.serializer.HistoryBranchToBlob(branchInfo) s.NoError(err) completionEventBatchId := int64(5) nextEventID := int64(7) @@ -341,11 +343,10 @@ func (s *workflowReplicatorSuite) Test_ApplyWorkflowState_Ancestors() { }, }, } - serializer := serialization.NewSerializer() 
var historyBlobs []*commonpb.DataBlob var nodeIds []int64 for _, history := range expectedHistory { - blob, err := serializer.SerializeEvents(history.GetEvents()) + blob, err := s.serializer.SerializeEvents(history.GetEvents()) s.NoError(err) historyBlobs = append(historyBlobs, blob) nodeIds = append(nodeIds, history.GetEvents()[0].GetEventId()) @@ -414,7 +415,7 @@ func (s *workflowReplicatorSuite) Test_ApplyWorkflowState_ExistWorkflow_Resend() BranchId: uuid.New(), Ancestors: nil, } - historyBranch, err := serialization.HistoryBranchToBlob(branchInfo) + historyBranch, err := s.serializer.HistoryBranchToBlob(branchInfo) s.NoError(err) completionEventBatchId := int64(5) nextEventID := int64(7) @@ -494,7 +495,7 @@ func (s *workflowReplicatorSuite) Test_ApplyWorkflowState_ExistWorkflow_SyncHSM( BranchId: uuid.New(), Ancestors: nil, } - historyBranch, err := serialization.HistoryBranchToBlob(branchInfo) + historyBranch, err := s.serializer.HistoryBranchToBlob(branchInfo) s.NoError(err) completionEventBatchId := int64(5) nextEventID := int64(7) @@ -596,7 +597,7 @@ func (s *workflowReplicatorSuite) Test_ReplicateVersionedTransition_SameBranch_S s.mockShard, s.mockWorkflowCache, nil, - serialization.NewSerializer(), + s.serializer, s.logger, ) mockTransactionManager := NewMockTransactionManager(s.controller) @@ -687,7 +688,7 @@ func (s *workflowReplicatorSuite) Test_ReplicateVersionedTransition_DifferentBra s.mockShard, s.mockWorkflowCache, nil, - serialization.NewSerializer(), + s.serializer, s.logger, ) mockTransactionManager := NewMockTransactionManager(s.controller) @@ -772,7 +773,7 @@ func (s *workflowReplicatorSuite) Test_ReplicateVersionedTransition_SameBranch_S s.mockShard, s.mockWorkflowCache, nil, - serialization.NewSerializer(), + s.serializer, s.logger, ) mockTransactionManager := NewMockTransactionManager(s.controller) @@ -866,7 +867,7 @@ func (s *workflowReplicatorSuite) Test_ReplicateVersionedTransition_FirstTask_Sy s.mockShard, s.mockWorkflowCache, nil, - 
serialization.NewSerializer(), + s.serializer, s.logger, ) mockTransactionManager := NewMockTransactionManager(s.controller) @@ -945,7 +946,7 @@ func (s *workflowReplicatorSuite) Test_ReplicateVersionedTransition_MutationProv s.mockShard, s.mockWorkflowCache, nil, - serialization.NewSerializer(), + s.serializer, s.logger, ) mockTransactionManager := NewMockTransactionManager(s.controller) @@ -1084,7 +1085,6 @@ func (s *workflowReplicatorSuite) Test_bringLocalEventsUpToSourceCurrentBranch_W }, }, } - serializer := serialization.NewSerializer() gapEvents := []*historypb.HistoryEvent{ {EventId: 21, Version: 2}, {EventId: 22, Version: 2}, {EventId: 23, Version: 2}, {EventId: 24, Version: 2}, @@ -1096,11 +1096,11 @@ func (s *workflowReplicatorSuite) Test_bringLocalEventsUpToSourceCurrentBranch_W {EventId: 27, Version: 2}, {EventId: 28, Version: 2}, {EventId: 29, Version: 2}, {EventId: 30, Version: 2}, } - blobs, err := serializer.SerializeEvents(requestedEvents) + blobs, err := s.serializer.SerializeEvents(requestedEvents) s.NoError(err) - gapBlobs, err := serializer.SerializeEvents(gapEvents) + gapBlobs, err := s.serializer.SerializeEvents(gapEvents) s.NoError(err) - tailBlobs, err := serializer.SerializeEvents(tailEvents) + tailBlobs, err := s.serializer.SerializeEvents(tailEvents) s.NoError(err) mockMutableState := historyi.NewMockMutableState(s.controller) mockMutableState.EXPECT().GetExecutionInfo().Return(&persistencespb.WorkflowExecutionInfo{ @@ -1239,7 +1239,6 @@ func (s *workflowReplicatorSuite) Test_bringLocalEventsUpToSourceCurrentBranch_W }, }, } - serializer := serialization.NewSerializer() gapEvents := []*historypb.HistoryEvent{ {EventId: 1, Version: 1}, {EventId: 2, Version: 1}, {EventId: 3, Version: 1}, } @@ -1249,11 +1248,11 @@ func (s *workflowReplicatorSuite) Test_bringLocalEventsUpToSourceCurrentBranch_W tailEvents := []*historypb.HistoryEvent{ {EventId: 5, Version: 2}, {EventId: 6, Version: 2}, } - blobs, err := 
serializer.SerializeEvents(requestedEvents) + blobs, err := s.serializer.SerializeEvents(requestedEvents) s.NoError(err) - gapBlobs, err := serializer.SerializeEvents(gapEvents) + gapBlobs, err := s.serializer.SerializeEvents(gapEvents) s.NoError(err) - tailBlobs, err := serializer.SerializeEvents(tailEvents) + tailBlobs, err := s.serializer.SerializeEvents(tailEvents) s.NoError(err) mockMutableState := historyi.NewMockMutableState(s.controller) mockMutableState.EXPECT().GetExecutionInfo().Return(&persistencespb.WorkflowExecutionInfo{ @@ -1405,7 +1404,6 @@ func (s *workflowReplicatorSuite) Test_bringLocalEventsUpToSourceCurrentBranch_W }, }, } - serializer := serialization.NewSerializer() gapEvents := []*historypb.HistoryEvent{ {EventId: 21, Version: 2}, {EventId: 22, Version: 2}, {EventId: 23, Version: 2}, {EventId: 24, Version: 2}, @@ -1417,11 +1415,11 @@ func (s *workflowReplicatorSuite) Test_bringLocalEventsUpToSourceCurrentBranch_W {EventId: 28, Version: 2}, {EventId: 29, Version: 2}, {EventId: 30, Version: 2}, } - blobs, err := serializer.SerializeEvents(requestedEvents) + blobs, err := s.serializer.SerializeEvents(requestedEvents) s.NoError(err) - gapBlobs, err := serializer.SerializeEvents(gapEvents) + gapBlobs, err := s.serializer.SerializeEvents(gapEvents) s.NoError(err) - tailBlobs, err := serializer.SerializeEvents(tailEvents) + tailBlobs, err := s.serializer.SerializeEvents(tailEvents) s.NoError(err) mockMutableState := historyi.NewMockMutableState(s.controller) mockMutableState.EXPECT().GetExecutionInfo().Return(&persistencespb.WorkflowExecutionInfo{ diff --git a/service/history/queues/queue_base_test.go b/service/history/queues/queue_base_test.go index 4b3c76b01f0..e3b6ddaa6dd 100644 --- a/service/history/queues/queue_base_test.go +++ b/service/history/queues/queue_base_test.go @@ -516,14 +516,15 @@ func (s *queueBaseSuite) QueueStateEqual( that *persistencespb.QueueState, ) { // ser/de so to equal will not take timezone into consideration - 
thisBlob, err := serialization.QueueStateToBlob(this) + serializer := serialization.NewSerializer() + thisBlob, err := serializer.QueueStateToBlob(this) s.NoError(err) - this, err = serialization.QueueStateFromBlob(thisBlob) + this, err = serializer.QueueStateFromBlob(thisBlob) s.NoError(err) - thatBlob, err := serialization.QueueStateToBlob(that) + thatBlob, err := serializer.QueueStateToBlob(that) s.NoError(err) - that, err = serialization.QueueStateFromBlob(thatBlob) + that, err = serializer.QueueStateFromBlob(thatBlob) s.NoError(err) s.Equal(this, that) diff --git a/service/history/replication/dlq_writer.go b/service/history/replication/dlq_writer.go index 085a71cfc94..0dee8a97283 100644 --- a/service/history/replication/dlq_writer.go +++ b/service/history/replication/dlq_writer.go @@ -5,9 +5,9 @@ import ( persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/service/history/configs" "go.temporal.io/server/service/history/queues" - "go.temporal.io/server/service/history/tasks" "go.uber.org/fx" ) @@ -39,16 +39,11 @@ type ( executionManagerDLQWriter struct { executionManager ExecutionManager } - // TaskParser is a trimmed version of [go.temporal.io/server/common/persistence/serialization.Serializer] - // that only provides the methods we need. - TaskParser interface { - ParseReplicationTask(replicationTask *persistencespb.ReplicationTaskInfo) (tasks.Task, error) - } // DLQWriterAdapter is a [DLQWriter] that uses the QueueV2 [queues.DLQWriter] object. 
DLQWriterAdapter struct { - dlqWriter *queues.DLQWriter - taskParser TaskParser - currentClusterName string + dlqWriter *queues.DLQWriter + replicationTaskSerializer serialization.ReplicationTaskSerializer + currentClusterName string } dlqWriterToggleParams struct { fx.In @@ -71,13 +66,13 @@ func NewExecutionManagerDLQWriter(executionManager ExecutionManager) *executionM // NewDLQWriterAdapter creates a new DLQWriter from a QueueV2 [queues.DLQWriter]. func NewDLQWriterAdapter( dlqWriter *queues.DLQWriter, - taskParser TaskParser, + replicationTaskSerializer serialization.ReplicationTaskSerializer, currentClusterName string, ) *DLQWriterAdapter { return &DLQWriterAdapter{ - dlqWriter: dlqWriter, - taskParser: taskParser, - currentClusterName: currentClusterName, + dlqWriter: dlqWriter, + replicationTaskSerializer: replicationTaskSerializer, + currentClusterName: currentClusterName, } } @@ -120,7 +115,7 @@ func (d *DLQWriterAdapter) WriteTaskToDLQ( ctx context.Context, request DLQWriteRequest, ) error { - task, err := d.taskParser.ParseReplicationTask(request.ReplicationTaskInfo) + task, err := d.replicationTaskSerializer.DeserializeReplicationTask(request.ReplicationTaskInfo) if err != nil { return err } diff --git a/service/history/replication/dlq_writer_test.go b/service/history/replication/dlq_writer_test.go index d0e4da87fca..5d48cebee2a 100644 --- a/service/history/replication/dlq_writer_test.go +++ b/service/history/replication/dlq_writer_test.go @@ -72,7 +72,7 @@ func TestNewDLQWriterAdapter(t *testing.T) { t.Run(tc.name, func(t *testing.T) { controller := gomock.NewController(t) queueWriter := &queuestest.FakeQueueWriter{} - taskSerializer := serialization.NewTaskSerializer() + taskSerializer := serialization.NewTaskSerializer(serialization.NewSerializer()) namespaceRegistry := namespace.NewMockRegistry(controller) namespaceRegistry.EXPECT().GetNamespaceByID(gomock.Any()).Return(&namespace.Namespace{}, nil).AnyTimes() metricsHandler := 
metricstest.NewCaptureHandler() diff --git a/service/history/replication/executable_backfill_history_events_task.go b/service/history/replication/executable_backfill_history_events_task.go index 73099a23438..910fd61074d 100644 --- a/service/history/replication/executable_backfill_history_events_task.go +++ b/service/history/replication/executable_backfill_history_events_task.go @@ -226,7 +226,7 @@ func (e *ExecutableBackfillHistoryEventsTask) HandleErr(err error) error { func (e *ExecutableBackfillHistoryEventsTask) getDeserializedEvents() (_ [][]*historypb.HistoryEvent, _ []*historypb.HistoryEvent, retError error) { eventBatches := [][]*historypb.HistoryEvent{} for _, eventsBlob := range e.taskAttr.EventBatches { - events, err := e.EventSerializer.DeserializeEvents(eventsBlob) + events, err := e.Serializer.DeserializeEvents(eventsBlob) if err != nil { e.Logger.Error("unable to deserialize history events", tag.WorkflowNamespaceID(e.NamespaceID), @@ -240,7 +240,7 @@ func (e *ExecutableBackfillHistoryEventsTask) getDeserializedEvents() (_ [][]*hi eventBatches = append(eventBatches, events) } - newRunEvents, err := e.EventSerializer.DeserializeEvents(e.taskAttr.NewRunInfo.EventBatch) + newRunEvents, err := e.Serializer.DeserializeEvents(e.taskAttr.NewRunInfo.EventBatch) if err != nil { e.Logger.Error("unable to deserialize new run history events", tag.WorkflowNamespaceID(e.NamespaceID), diff --git a/service/history/replication/executable_backfill_history_events_task_test.go b/service/history/replication/executable_backfill_history_events_task_test.go index d13d101ff7e..70d6c16bb4c 100644 --- a/service/history/replication/executable_backfill_history_events_task_test.go +++ b/service/history/replication/executable_backfill_history_events_task_test.go @@ -151,7 +151,7 @@ func (s *executableBackfillHistoryEventsTaskSuite) SetupTest() { NamespaceCache: s.namespaceCache, MetricsHandler: s.metricsHandler, Logger: s.logger, - EventSerializer: s.eventSerializer, + Serializer: 
s.eventSerializer, EagerNamespaceRefresher: s.eagerNamespaceRefresher, DLQWriter: NewExecutionManagerDLQWriter(s.mockExecutionManager), Config: s.config, diff --git a/service/history/replication/executable_history_task.go b/service/history/replication/executable_history_task.go index 46ecc00d2a9..aa64c07c20f 100644 --- a/service/history/replication/executable_history_task.go +++ b/service/history/replication/executable_history_task.go @@ -192,7 +192,7 @@ func (e *ExecutableHistoryTask) MarkPoisonPill() error { if e.ReplicationTask().GetRawTaskInfo() == nil { eventBatches := [][]*historypb.HistoryEvent{} for _, eventsBlob := range e.eventsBlobs { - events, err := e.EventSerializer.DeserializeEvents(eventsBlob) + events, err := e.Serializer.DeserializeEvents(eventsBlob) if err != nil { e.Logger.Error("unable to enqueue history replication task to DLQ, ser/de error", tag.ShardID(shardContext.GetShardID()), @@ -256,7 +256,7 @@ func (e *ExecutableHistoryTask) getDeserializedEvents() (_ [][]*historypb.Histor eventBatches := [][]*historypb.HistoryEvent{} for _, eventsBlob := range e.eventsBlobs { - events, err := e.EventSerializer.DeserializeEvents(eventsBlob) + events, err := e.Serializer.DeserializeEvents(eventsBlob) if err != nil { e.Logger.Error("unable to deserialize history events", tag.WorkflowNamespaceID(e.NamespaceID), @@ -270,7 +270,7 @@ func (e *ExecutableHistoryTask) getDeserializedEvents() (_ [][]*historypb.Histor eventBatches = append(eventBatches, events) } - newRunEvents, err := e.EventSerializer.DeserializeEvents(e.newRunEventsBlob) + newRunEvents, err := e.Serializer.DeserializeEvents(e.newRunEventsBlob) if err != nil { e.Logger.Error("unable to deserialize new run history events", tag.WorkflowNamespaceID(e.NamespaceID), diff --git a/service/history/replication/executable_history_task_test.go b/service/history/replication/executable_history_task_test.go index 0e483b3d728..ea1ef3e407f 100644 --- 
a/service/history/replication/executable_history_task_test.go +++ b/service/history/replication/executable_history_task_test.go @@ -123,7 +123,7 @@ func (s *executableHistoryTaskSuite) SetupTest() { MetricsHandler: s.metricsHandler, Logger: s.logger, EagerNamespaceRefresher: s.eagerNamespaceRefresher, - EventSerializer: s.eventSerializer, + Serializer: s.eventSerializer, DLQWriter: NewExecutionManagerDLQWriter(s.mockExecutionManager), Config: tests.NewDynamicConfig(), HistoryEventsHandler: s.mockEventHandler, diff --git a/service/history/replication/executable_task.go b/service/history/replication/executable_task.go index a2d03809f6e..6873fb278bb 100644 --- a/service/history/replication/executable_task.go +++ b/service/history/replication/executable_task.go @@ -23,7 +23,6 @@ import ( "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/namespace" - "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/versionhistory" serviceerrors "go.temporal.io/server/common/serviceerror" ctasks "go.temporal.io/server/common/tasks" @@ -521,7 +520,7 @@ func (e *ExecutableTaskImpl) BackFillEvents( if err != nil { return serviceerror.NewInternalf("failed to get new run history when backfill: %v", err) } - events, err := e.EventSerializer.DeserializeEvents(batch.RawEventBatch) + events, err := e.Serializer.DeserializeEvents(batch.RawEventBatch) if err != nil { return serviceerror.NewInternalf("failed to deserailize run history events when backfill: %v", err) } @@ -565,7 +564,7 @@ func (e *ExecutableTaskImpl) BackFillEvents( if err != nil { return err } - events, err := e.EventSerializer.DeserializeEvents(batch.RawEventBatch) + events, err := e.Serializer.DeserializeEvents(batch.RawEventBatch) if err != nil { return err } @@ -656,7 +655,7 @@ func (e *ExecutableTaskImpl) SyncState( tasksToAdd := make([]*adminservice.AddTasksRequest_Task, 0, len(taskEquivalents)) for _, taskEquivalent := 
range taskEquivalents { - blob, err := serialization.ReplicationTaskInfoToBlob(taskEquivalent) + blob, err := e.Serializer.ReplicationTaskInfoToBlob(taskEquivalent) if err != nil { return false, err } diff --git a/service/history/replication/executable_task_test.go b/service/history/replication/executable_task_test.go index 5332937f025..486c58065d6 100644 --- a/service/history/replication/executable_task_test.go +++ b/service/history/replication/executable_task_test.go @@ -127,7 +127,7 @@ func (s *executableTaskSuite) SetupTest() { Logger: s.logger, EagerNamespaceRefresher: s.eagerNamespaceRefresher, DLQWriter: NewExecutionManagerDLQWriter(s.mockExecutionManager), - EventSerializer: s.serializer, + Serializer: s.serializer, RemoteHistoryFetcher: s.remoteHistoryFetcher, } diff --git a/service/history/replication/executable_task_tool_box.go b/service/history/replication/executable_task_tool_box.go index d51e36c7f12..17625b18c3a 100644 --- a/service/history/replication/executable_task_tool_box.go +++ b/service/history/replication/executable_task_tool_box.go @@ -31,7 +31,7 @@ type ( LowPriorityTaskScheduler ctasks.Scheduler[TrackableExecutableTask] `name:"LowPriorityTaskScheduler"` MetricsHandler metrics.Handler Logger log.Logger - EventSerializer serialization.Serializer + Serializer serialization.Serializer DLQWriter DLQWriter HistoryEventsHandler eventhandler.HistoryEventsHandler WorkflowCache wcache.Cache diff --git a/service/history/replication/executable_verify_versioned_transition_task_test.go b/service/history/replication/executable_verify_versioned_transition_task_test.go index 7322aeea361..e4d105cab6f 100644 --- a/service/history/replication/executable_verify_versioned_transition_task_test.go +++ b/service/history/replication/executable_verify_versioned_transition_task_test.go @@ -112,7 +112,7 @@ func (s *executableVerifyVersionedTransitionTaskSuite) SetupTest() { NamespaceCache: s.namespaceCache, MetricsHandler: s.metricsHandler, Logger: s.logger, - 
EventSerializer: s.eventSerializer, + Serializer: s.eventSerializer, EagerNamespaceRefresher: s.eagerNamespaceRefresher, DLQWriter: NewExecutionManagerDLQWriter(s.mockExecutionManager), Config: s.config, diff --git a/service/history/replication/fx.go b/service/history/replication/fx.go index f05ed3c0478..09c23069da7 100644 --- a/service/history/replication/fx.go +++ b/service/history/replication/fx.go @@ -41,6 +41,9 @@ var Module = fx.Provide( NewExecutionManagerDLQWriter, ClientSchedulerRateLimiterProvider, ServerSchedulerRateLimiterProvider, + func(serializer serialization.Serializer) serialization.ReplicationTaskSerializer { + return serialization.NewTaskSerializer(serializer) + }, replicationTaskConverterFactoryProvider, replicationTaskExecutorProvider, fx.Annotated{ @@ -335,10 +338,10 @@ func eventImporterProvider( func dlqWriterAdapterProvider( dlqWriter *queues.DLQWriter, - taskSerializer serialization.Serializer, + replicationTaskSerializer serialization.ReplicationTaskSerializer, clusterMetadata cluster.Metadata, ) *DLQWriterAdapter { - return NewDLQWriterAdapter(dlqWriter, taskSerializer, clusterMetadata.GetCurrentClusterName()) + return NewDLQWriterAdapter(dlqWriter, replicationTaskSerializer, clusterMetadata.GetCurrentClusterName()) } func historyEventsHandlerProvider( diff --git a/service/history/replication/raw_task_converter.go b/service/history/replication/raw_task_converter.go index 4a81093a89c..d2ea27aadc7 100644 --- a/service/history/replication/raw_task_converter.go +++ b/service/history/replication/raw_task_converter.go @@ -35,10 +35,11 @@ import ( type ( SourceTaskConverterImpl struct { - historyEngine historyi.Engine - namespaceCache namespace.Registry - serializer serialization.Serializer - config *configs.Config + historyEngine historyi.Engine + namespaceCache namespace.Registry + serializer serialization.Serializer + replicationTaskSerializer serialization.ReplicationTaskSerializer + config *configs.Config } SourceTaskConverter interface { 
Convert(task tasks.Task, targetClusterID int32, priority enumsspb.TaskPriority) (*replicationspb.ReplicationTask, error) @@ -69,10 +70,11 @@ func NewSourceTaskConverter( config *configs.Config, ) *SourceTaskConverterImpl { return &SourceTaskConverterImpl{ - historyEngine: historyEngine, - namespaceCache: namespaceCache, - serializer: serializer, - config: config, + historyEngine: historyEngine, + namespaceCache: namespaceCache, + serializer: serializer, + replicationTaskSerializer: serialization.NewTaskSerializer(serializer), + config: config, } } @@ -103,7 +105,7 @@ func (c *SourceTaskConverterImpl) Convert( return nil, err } if replicationTask != nil { - rawTaskInfo, err := c.serializer.ParseReplicationTaskInfo(task) + rawTaskInfo, err := c.replicationTaskSerializer.SeralizeReplicationTask(task) if err != nil { return nil, err } diff --git a/service/history/shard/context_impl.go b/service/history/shard/context_impl.go index 582b2777fe6..93b3aab328b 100644 --- a/service/history/shard/context_impl.go +++ b/service/history/shard/context_impl.go @@ -359,8 +359,8 @@ func (s *ContextImpl) GetQueueState( return nil, false } // need to make a deep copy, in case UpdateReplicationQueueReaderState does a partial update - blob, _ := serialization.QueueStateToBlob(queueState) - queueState, _ = serialization.QueueStateFromBlob(blob) + blob, _ := s.payloadSerializer.QueueStateToBlob(queueState) + queueState, _ = s.payloadSerializer.QueueStateFromBlob(blob) return queueState, ok } @@ -1168,7 +1168,7 @@ func (s *ContextImpl) renewRangeLocked(isStealing bool) error { // before calling this method. 
s.taskKeyManager.drainTaskRequests() - updatedShardInfo := trimShardInfo(s.clusterMetadata.GetAllClusterInfo(), copyShardInfo(s.shardInfo)) + updatedShardInfo := trimShardInfo(s.clusterMetadata.GetAllClusterInfo(), s.copyShardInfo(s.shardInfo)) updatedShardInfo.RangeId++ if isStealing { updatedShardInfo.StolenSinceRenew++ @@ -1199,7 +1199,7 @@ func (s *ContextImpl) renewRangeLocked(isStealing bool) error { tag.PreviousShardRangeID(s.shardInfo.RangeId), ) - s.shardInfo = trimShardInfo(s.clusterMetadata.GetAllClusterInfo(), copyShardInfo(updatedShardInfo)) + s.shardInfo = trimShardInfo(s.clusterMetadata.GetAllClusterInfo(), s.copyShardInfo(updatedShardInfo)) s.taskKeyManager.setRangeID(s.shardInfo.RangeId) return nil @@ -1256,7 +1256,7 @@ func (s *ContextImpl) updateShardInfo( s.lastUpdated = now s.tasksCompletedSinceLastUpdate = 0 - updatedShardInfo := trimShardInfo(s.clusterMetadata.GetAllClusterInfo(), copyShardInfo(s.shardInfo)) + updatedShardInfo := trimShardInfo(s.clusterMetadata.GetAllClusterInfo(), s.copyShardInfo(s.shardInfo)) request := &persistence.UpdateShardRequest{ ShardInfo: updatedShardInfo, PreviousRangeID: s.shardInfo.GetRangeId(), @@ -1289,7 +1289,7 @@ func (s *ContextImpl) emitShardInfoMetricsLogs() { s.rLock() defer s.rUnlock() - queueStates := trimShardInfo(s.clusterMetadata.GetAllClusterInfo(), copyShardInfo(s.shardInfo)).QueueStates + queueStates := trimShardInfo(s.clusterMetadata.GetAllClusterInfo(), s.copyShardInfo(s.shardInfo)).QueueStates emitShardLagLog := s.config.EmitShardLagLog() metricsHandler := s.GetMetricsHandler().WithTags(metrics.OperationTag(metrics.ShardInfoScope)) @@ -1802,7 +1802,7 @@ func (s *ContextImpl) loadShardMetadata(ownershipChanged *bool) error { return err } *ownershipChanged = resp.ShardInfo.Owner != s.owner - shardInfo := trimShardInfo(s.clusterMetadata.GetAllClusterInfo(), copyShardInfo(resp.ShardInfo)) + shardInfo := trimShardInfo(s.clusterMetadata.GetAllClusterInfo(), s.copyShardInfo(resp.ShardInfo)) 
shardInfo.Owner = s.owner // initialize the cluster current time to be the same as ack level @@ -2156,12 +2156,12 @@ func (s *ContextImpl) initLastUpdatesTime() { } // TODO: why do we need a deep copy here? -func copyShardInfo(shardInfo *persistencespb.ShardInfo) *persistencespb.ShardInfo { +func (s *ContextImpl) copyShardInfo(shardInfo *persistencespb.ShardInfo) *persistencespb.ShardInfo { // need to ser/de to make a deep copy of queue state queueStates := make(map[int32]*persistencespb.QueueState, len(shardInfo.QueueStates)) for k, v := range shardInfo.QueueStates { - blob, _ := serialization.QueueStateToBlob(v) - queueState, _ := serialization.QueueStateFromBlob(blob) + blob, _ := s.payloadSerializer.QueueStateToBlob(v) + queueState, _ := s.payloadSerializer.QueueStateFromBlob(blob) queueStates[k] = queueState } diff --git a/service/history/workflow/workflow_test/mutable_state_impl_test.go b/service/history/workflow/workflow_test/mutable_state_impl_test.go index dcc11807c56..83091f7d759 100644 --- a/service/history/workflow/workflow_test/mutable_state_impl_test.go +++ b/service/history/workflow/workflow_test/mutable_state_impl_test.go @@ -28,6 +28,7 @@ import ( "go.temporal.io/server/common/log" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/versionhistory" "go.temporal.io/server/service/history/configs" "go.temporal.io/server/service/history/events" @@ -210,7 +211,7 @@ func createMutableState(t *testing.T, nsEntry *namespace.Namespace, cfg *configs clusterMetadata.EXPECT().GetClusterID().Return(int64(1)).AnyTimes() executionManager := shardContext.Resource.ExecutionMgr - executionManager.EXPECT().GetHistoryBranchUtil().Return(&persistence.HistoryBranchUtilImpl{}).AnyTimes() + executionManager.EXPECT().GetHistoryBranchUtil().Return(persistence.NewHistoryBranchUtil(serialization.NewSerializer())).AnyTimes() startTime := 
time.Time{} logger := log.NewNoopLogger() diff --git a/service/worker/dlq/workflow_test.go b/service/worker/dlq/workflow_test.go index bf7e74d8b0b..d3f0cca234c 100644 --- a/service/worker/dlq/workflow_test.go +++ b/service/worker/dlq/workflow_test.go @@ -376,7 +376,7 @@ func TestModule(t *testing.T) { params.workflowParams.MergeParams.Key.TaskCategoryID = tasks.CategoryIDReplication params.expectedQueryResp.DlqKey = params.workflowParams.MergeParams.Key var replicationTask tasks.HistoryReplicationTask - blob, err := serialization.NewTaskSerializer().SerializeTask(&replicationTask) + blob, err := serialization.NewTaskSerializer(serialization.NewSerializer()).SerializeTask(&replicationTask) require.NoError(t, err) params.client.getTasksFn = func(req *historyservice.GetDLQTasksRequest) (*historyservice.GetDLQTasksResponse, error) { return &historyservice.GetDLQTasksResponse{ diff --git a/service/worker/scanner/history/scavenger.go b/service/worker/scanner/history/scavenger.go index f1eb319f855..91207670ac5 100644 --- a/service/worker/scanner/history/scavenger.go +++ b/service/worker/scanner/history/scavenger.go @@ -46,6 +46,7 @@ type ( rateLimiter quotas.RateLimiter metricsHandler metrics.Handler logger log.Logger + serializer serialization.Serializer isInTest bool // only clean up history branches that older than this age // Our history archiver delete mutable state, and then upload history to blob store and then delete history. 
@@ -92,8 +93,8 @@ func NewScavenger( enableRetentionVerification dynamicconfig.BoolPropertyFn, metricsHandler metrics.Handler, logger log.Logger, + serializer serialization.Serializer, ) *Scavenger { - return &Scavenger{ numShards: numShards, db: db, @@ -108,8 +109,8 @@ func NewScavenger( enableRetentionVerification: enableRetentionVerification, metricsHandler: metricsHandler.WithTags(metrics.OperationTag(metrics.HistoryScavengerScope)), logger: logger, - - hbd: hbd, + serializer: serializer, + hbd: hbd, } } @@ -226,7 +227,7 @@ func (s *Scavenger) filterTask( } shardID := common.WorkflowIDToHistoryShard(namespaceID, workflowID, s.numShards) - branchToken, err := serialization.HistoryBranchToBlob(branch.BranchInfo) + branchToken, err := s.serializer.HistoryBranchToBlob(branch.BranchInfo) if err != nil { s.logger.Error("unable to serialize the history branch token", tag.DetailInfo(branch.Info), tag.Error(err)) metrics.HistoryScavengerErrorCount.With(s.metricsHandler).Record(1) diff --git a/service/worker/scanner/history/scavenger_test.go b/service/worker/scanner/history/scavenger_test.go index c5eea8480bd..4cbeafd574a 100644 --- a/service/worker/scanner/history/scavenger_test.go +++ b/service/worker/scanner/history/scavenger_test.go @@ -22,6 +22,7 @@ import ( "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/primitives/timestamp" "go.temporal.io/server/common/testing/protomock" @@ -102,8 +103,10 @@ func (s *ScavengerTestSuite) createTestScavenger( enableRetentionVerification, s.metricHandler, s.logger, + serialization.NewSerializer(), ) s.scavenger.isInTest = true + s.historyBranchUtil = *persistence.NewHistoryBranchUtil(serialization.NewSerializer()) } func (s *ScavengerTestSuite) TestAllSkipTasksTwoPages() { diff --git a/service/worker/scanner/scanner.go 
b/service/worker/scanner/scanner.go index 3bd067a2251..05f944afcb6 100644 --- a/service/worker/scanner/scanner.go +++ b/service/worker/scanner/scanner.go @@ -24,6 +24,7 @@ import ( "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/visibility/manager" "go.temporal.io/server/common/sdk" "go.temporal.io/server/service/worker/scanner/build_ids" @@ -89,6 +90,7 @@ type ( namespaceRegistry namespace.Registry currentClusterName string hostInfo membership.HostInfo + serializer serialization.Serializer } // Scanner is the background sub-system that does full scans @@ -121,6 +123,7 @@ func New( registry namespace.Registry, currentClusterName string, hostInfo membership.HostInfo, + serializer serialization.Serializer, ) *Scanner { return &Scanner{ context: scannerContext{ @@ -138,6 +141,7 @@ func New( namespaceRegistry: registry, currentClusterName: currentClusterName, hostInfo: hostInfo, + serializer: serializer, }, } } diff --git a/service/worker/scanner/scanner_test.go b/service/worker/scanner/scanner_test.go index e89d4c03fab..19f60bb5071 100644 --- a/service/worker/scanner/scanner_test.go +++ b/service/worker/scanner/scanner_test.go @@ -16,6 +16,7 @@ import ( "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/namespace" p "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/sdk" "go.temporal.io/server/common/testing/mocksdk" "go.temporal.io/server/service/worker/scanner/build_ids" @@ -213,6 +214,7 @@ func (s *scannerTestSuite) TestScannerEnabled() { mockNamespaceRegistry, "active-cluster", membership.NewHostInfoFromAddress("localhost"), + serialization.NewSerializer(), ) var wg sync.WaitGroup for _, sc := range c.ExpectedScanners { @@ -289,6 +291,7 @@ func (s *scannerTestSuite) TestScannerShutdown() { 
mockNamespaceRegistry, "active-cluster", membership.NewHostInfoFromAddress("localhost"), + serialization.NewSerializer(), ) mockSdkClientFactory.EXPECT().GetSystemClient().Return(mockSdkClient).AnyTimes() worker.EXPECT().RegisterActivityWithOptions(gomock.Any(), gomock.Any()).AnyTimes() diff --git a/service/worker/scanner/workflow.go b/service/worker/scanner/workflow.go index 6681729ba76..f3aa190d4a9 100644 --- a/service/worker/scanner/workflow.go +++ b/service/worker/scanner/workflow.go @@ -133,6 +133,7 @@ func HistoryScavengerActivity( ctx.cfg.HistoryScannerVerifyRetention, ctx.metricsHandler, ctx.logger, + ctx.serializer, ) return scavenger.Run(activityCtx) } diff --git a/service/worker/service.go b/service/worker/service.go index a153efe93e5..6ef02b6113d 100644 --- a/service/worker/service.go +++ b/service/worker/service.go @@ -17,6 +17,7 @@ import ( "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/namespace/nsreplication" "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/visibility/manager" "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/resource" @@ -111,6 +112,7 @@ func NewService( visibilityManager manager.VisibilityManager, matchingClient resource.MatchingClient, namespaceReplicationTaskExecutor nsreplication.TaskExecutor, + serializer serialization.Serializer, ) (*Service, error) { workerServiceResolver, err := membershipMonitor.GetResolver(primitives.WorkerService) if err != nil { @@ -141,7 +143,7 @@ func NewService( matchingClient: matchingClient, namespaceReplicationTaskExecutor: namespaceReplicationTaskExecutor, } - if err := s.initScanner(); err != nil { + if err := s.initScanner(serializer); err != nil { return nil, err } return s, nil @@ -282,7 +284,7 @@ func (s *Service) startParentClosePolicyProcessor() { } } -func (s *Service) initScanner() error { +func (s *Service) initScanner(serializer 
serialization.Serializer) error { currentCluster := s.clusterMetadata.GetCurrentClusterName() adminClient, err := s.clientBean.GetRemoteAdminClient(currentCluster) if err != nil { @@ -303,6 +305,7 @@ func (s *Service) initScanner() error { s.namespaceRegistry, currentCluster, s.hostInfo, + serializer, ) return nil } diff --git a/temporal/fx.go b/temporal/fx.go index a21664003cc..91010b200e6 100644 --- a/temporal/fx.go +++ b/temporal/fx.go @@ -35,6 +35,7 @@ import ( "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/cassandra" persistenceClient "go.temporal.io/server/common/persistence/client" + "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/sql" "go.temporal.io/server/common/persistence/visibility" esclient "go.temporal.io/server/common/persistence/visibility/store/elasticsearch/client" @@ -584,6 +585,7 @@ func ApplyClusterMetadataConfigProvider( logger = log.With(logger, tag.ComponentMetadataInitializer) metricsHandler = metricsHandler.WithTags(metrics.ServiceNameTag(primitives.ServerService)) clusterName := persistenceClient.ClusterName(svc.ClusterMetadata.CurrentClusterName) + serializer := serialization.NewSerializer() dataStoreFactory := persistenceClient.DataStoreFactoryProvider( clusterName, persistenceServiceResolver, @@ -592,6 +594,7 @@ func ApplyClusterMetadataConfigProvider( logger, metricsHandler, telemetry.NoopTracerProvider, + serializer, ) factory := persistenceFactoryProvider(persistenceClient.NewFactoryParams{ DataStoreFactory: dataStoreFactory, @@ -601,6 +604,7 @@ func ApplyClusterMetadataConfigProvider( ClusterName: persistenceClient.ClusterName(svc.ClusterMetadata.CurrentClusterName), MetricsHandler: metricsHandler, Logger: logger, + Serializer: serializer, }) defer factory.Close() diff --git a/temporal/server_impl.go b/temporal/server_impl.go index 03d0bd17ffa..13767faaf1a 100644 --- a/temporal/server_impl.go +++ b/temporal/server_impl.go @@ -15,6 +15,7 @@ 
import ( "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" persistenceClient "go.temporal.io/server/common/persistence/client" + "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/resolver" "go.temporal.io/server/common/resource" @@ -151,6 +152,7 @@ func initSystemNamespaces( ) error { clusterName := persistenceClient.ClusterName(currentClusterName) metricsHandler = metricsHandler.WithTags(metrics.ServiceNameTag(primitives.ServerService)) + serializer := serialization.NewSerializer() dataStoreFactory := persistenceClient.DataStoreFactoryProvider( clusterName, persistenceServiceResolver, @@ -159,6 +161,7 @@ func initSystemNamespaces( logger, metricsHandler, telemetry.NoopTracerProvider, + serializer, ) factory := persistenceFactoryProvider(persistenceClient.NewFactoryParams{ DataStoreFactory: dataStoreFactory, @@ -168,6 +171,7 @@ func initSystemNamespaces( ClusterName: persistenceClient.ClusterName(currentClusterName), MetricsHandler: metricsHandler, Logger: logger, + Serializer: serializer, }) defer factory.Close() diff --git a/temporal/server_test.go b/temporal/server_test.go index 07172d71132..df00b21fe23 100644 --- a/temporal/server_test.go +++ b/temporal/server_test.go @@ -3,6 +3,7 @@ package temporal_test import ( "context" "fmt" + "os" "path" "strings" "sync/atomic" @@ -11,21 +12,30 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + enumspb "go.temporal.io/api/enums/v1" + "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/worker" + "go.temporal.io/sdk/workflow" + "go.temporal.io/server/api/adminservice/v1" "go.temporal.io/server/common/config" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" + "go.temporal.io/server/common/persistence/serialization" _ "go.temporal.io/server/common/persistence/sql/sqlplugin/sqlite" // needed to register the sqlite plugin 
"go.temporal.io/server/common/testing/testtelemetry" "go.temporal.io/server/service/frontend" "go.temporal.io/server/temporal" "go.temporal.io/server/tests/testutils" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/protobuf/types/known/durationpb" ) // TestNewServer verifies that NewServer doesn't cause any fx errors, and that there are no unexpected error logs after // running for a few seconds. func TestNewServer(t *testing.T) { - startAndStopServer(t) + startAndVerifyServer(t) } // TestNewServerWithOTEL verifies that NewServer doesn't cause any fx errors when OTEL is enabled. @@ -38,21 +48,24 @@ func TestNewServerWithOTEL(t *testing.T) { collector, err := testtelemetry.StartMemoryCollector(t) require.NoError(t, err) t.Setenv("OTEL_EXPORTER_OTLP_ENDPOINT", collector.Addr()) - startAndStopServer(t) + startAndVerifyServer(t) require.NotEmpty(t, collector.Spans(), "expected at least one OTEL span") }) t.Run("without OTEL Collector running", func(t *testing.T) { - startAndStopServer(t) + startAndVerifyServer(t) }) } -func startAndStopServer(t *testing.T) { +// TestNewServerWithJSONEncoding verifies that NewServer doesn't cause any fx errors when JSON encoding is enabled. +func TestNewServerWithJSONEncoding(t *testing.T) { + t.Setenv(serialization.SerializerDataEncodingEnvVar, "json") + startAndVerifyServer(t) +} + +func startAndVerifyServer(t *testing.T) { cfg := loadConfig(t) - // The prometheus reporter does not shut down in-between test runs. - // This will assign a random port to the prometheus reporter, - // so that it doesn't conflict with other tests. 
- cfg.Global.Metrics.Prometheus.ListenAddress = ":0" + logDetector := newErrorLogDetector(t, log.NewTestLogger()) logDetector.Start() @@ -70,13 +83,70 @@ startAndStopServer(t *testing.T) { }) require.NoError(t, server.Start()) - time.Sleep(10 * time.Second) //nolint:forbidigo + + // Create SDK client. + frontendHostPort := fmt.Sprintf("127.0.0.1:%d", cfg.Services["frontend"].RPC.GRPCPort) + c, err := client.Dial(client.Options{ + HostPort: frontendHostPort, + Namespace: "default", + }) + require.NoError(t, err) + defer c.Close() + + ctx, cancel := context.WithTimeout(t.Context(), 60*time.Second) + defer cancel() + + // Register default namespace. + _, err = c.WorkflowService().RegisterNamespace(ctx, &workflowservice.RegisterNamespaceRequest{ + Namespace: "default", + WorkflowExecutionRetentionPeriod: durationpb.New(24 * time.Hour), + }) + require.NoError(t, err) + + // Start workflow. + taskQueue := "test-task-queue" + run, err := c.ExecuteWorkflow(ctx, client.StartWorkflowOptions{TaskQueue: taskQueue}, SimpleWorkflow) + require.NoError(t, err) + + // Check that the workflow task was backlogged (to test task persistence). + adminConn, err := grpc.NewClient(frontendHostPort, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer adminConn.Close() + adminClient := adminservice.NewAdminServiceClient(adminConn) + assert.Eventually(t, func() bool { + response, err := adminClient.GetTaskQueueTasks(ctx, &adminservice.GetTaskQueueTasksRequest{ + Namespace: "default", + TaskQueue: taskQueue, + TaskQueueType: enumspb.TASK_QUEUE_TYPE_WORKFLOW, + BatchSize: 10, + }) + if err != nil { + return false + } + return len(response.Tasks) > 0 + }, 20*time.Second, 100*time.Millisecond) + + // Start worker. + w := worker.New(c, taskQueue, worker.Options{}) + w.RegisterWorkflow(SimpleWorkflow) + err = w.Start() + require.NoError(t, err) + defer w.Stop() + + // Wait for the workflow to complete.
+ var result string + err = run.Get(ctx, &result) + require.NoError(t, err) + assert.Equal(t, "Hello World", result) + + // Verify workflow history can be retrieved. + iter := c.GetWorkflowHistory(ctx, run.GetID(), run.GetRunID(), false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT) + require.True(t, iter.HasNext()) } func loadConfig(t *testing.T) *config.Config { cfg := loadSQLiteConfig(t) setTestPorts(cfg) - return cfg } @@ -89,23 +159,42 @@ func loadSQLiteConfig(t *testing.T) *config.Config { cfg.DynamicConfigClient.Filepath = path.Join(configDir, "dynamicconfig", "development-sql.yaml") + // Use a unique temporary file for each test to avoid conflicts + tmpFile := fmt.Sprintf("/tmp/temporal_test_%d.db", time.Now().UnixNano()) + t.Cleanup(func() { + _ = os.Remove(tmpFile) + }) + for name, store := range cfg.Persistence.DataStores { + store.SQL.DatabaseName = tmpFile + cfg.Persistence.DataStores[name] = store + } + return cfg } -// setTestPorts sets the ports of all services to something different from the default ports, so that we can run the -// tests in parallel. +// setTestPorts sets the ports of all services to something different from the default ports. func setTestPorts(cfg *config.Config) { port := 10000 + // The prometheus reporter does not shut down in-between test runs. + // This will assign a random port to the prometheus reporter, + // so that it doesn't conflict with other tests. 
+ cfg.Global.Metrics.Prometheus.ListenAddress = ":0" + for k, v := range cfg.Services { - rpc := v.RPC - rpc.GRPCPort = port + v.RPC.GRPCPort = port + port++ + + v.RPC.MembershipPort = port port++ - rpc.MembershipPort = port + + v.RPC.HTTPPort = port port++ - v.RPC = rpc + cfg.Services[k] = v } + + cfg.Global.PProf.Port = port } func getFrontendInterceptors() func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { @@ -178,6 +267,7 @@ func (d *errorLogDetector) Error(msg string, tags ...tag.Tag) { "Unable to process new range", "Unable to call", "service failures", + "unable to retrieve tasks", // normal during startup } { if strings.Contains(msg, s) { return @@ -248,3 +338,7 @@ func TestErrorLogDetector(t *testing.T) { d.Fatal("fatal") assert.Empty(t, f.errorLogs, "should not fail the test if the detector is stopped") } + +func SimpleWorkflow(ctx workflow.Context) (string, error) { + return "Hello World", nil +} diff --git a/tests/add_tasks_test.go b/tests/add_tasks_test.go index c6c4ded6f34..40fa79eb2c7 100644 --- a/tests/add_tasks_test.go +++ b/tests/add_tasks_test.go @@ -184,7 +184,7 @@ func (s *AddTasksSuite) TestAddTasks_Ok() { } s.shouldSkip.Store(false) - blob, err := serialization.NewTaskSerializer().SerializeTask(task) + blob, err := serialization.NewTaskSerializer(serialization.NewSerializer()).SerializeTask(task) s.NoError(err) shardID := tasks.GetShardIDForTask(task, int(s.GetTestClusterConfig().HistoryConfig.NumHistoryShards)) request := &adminservice.AddTasksRequest{ diff --git a/tests/gethistory_test.go b/tests/gethistory_test.go index c228f1dfd8b..0476530afef 100644 --- a/tests/gethistory_test.go +++ b/tests/gethistory_test.go @@ -530,12 +530,11 @@ func (s *RawHistorySuite) TestGetWorkflowExecutionHistory_GetRawHistoryData() { return responseInner.RawHistory, responseInner.NextPageToken } - serializer := serialization.NewSerializer() convertBlob := func(blobs []*commonpb.DataBlob) 
[]*historypb.HistoryEvent { events := []*historypb.HistoryEvent{} for _, blob := range blobs { s.True(blob.GetEncodingType() == enumspb.ENCODING_TYPE_PROTO3) - blobEvents, err := serializer.DeserializeEvents(&commonpb.DataBlob{ + blobEvents, err := serialization.DefaultDecoder.DeserializeEvents(&commonpb.DataBlob{ EncodingType: enumspb.ENCODING_TYPE_PROTO3, Data: blob.Data, }) diff --git a/tests/ndc/ndc_test.go b/tests/ndc/ndc_test.go index c1373d37b74..f8eb23f3851 100644 --- a/tests/ndc/ndc_test.go +++ b/tests/ndc/ndc_test.go @@ -2409,7 +2409,7 @@ func (s *NDCFunctionalTestSuite) setupRemoteFrontendClients() { func (s *NDCFunctionalTestSuite) sizeOfHistoryEvents( events []*historypb.HistoryEvent, ) int64 { - blob, err := serialization.NewSerializer().SerializeEvents(events) + blob, err := s.serializer.SerializeEvents(events) s.NoError(err) return int64(len(blob.Data)) } diff --git a/tests/ndc/replication_migration_back_test.go b/tests/ndc/replication_migration_back_test.go index e07f0bda2e7..ce4a2ba75e3 100644 --- a/tests/ndc/replication_migration_back_test.go +++ b/tests/ndc/replication_migration_back_test.go @@ -344,11 +344,10 @@ func (s *ReplicationMigrationBackTestSuite) assertHistoryEvents( Return(s.passiveCluster.AdminClient(), nil). AnyTimes() - serializer := serialization.NewSerializer() passiveClusterFetcher := eventhandler.NewHistoryPaginatedFetcher( nil, mockClientBean, - serializer, + s.serializer, s.logger, ) diff --git a/tests/ndc/replication_task_batching_test.go b/tests/ndc/replication_task_batching_test.go index 16f82c66f2d..8b83e026c84 100644 --- a/tests/ndc/replication_task_batching_test.go +++ b/tests/ndc/replication_task_batching_test.go @@ -169,11 +169,10 @@ func (s *NDCReplicationTaskBatchingTestSuite) assertHistoryEvents( Return(s.passtiveCluster.AdminClient(), nil). 
AnyTimes() - serializer := serialization.NewSerializer() passiveClusterFetcher := eventhandler.NewHistoryPaginatedFetcher( nil, mockClientBean, - serializer, + s.serializer, s.logger, ) diff --git a/tests/xdc/history_replication_dlq_test.go b/tests/xdc/history_replication_dlq_test.go index 3b8b43788e2..490f36be477 100644 --- a/tests/xdc/history_replication_dlq_test.go +++ b/tests/xdc/history_replication_dlq_test.go @@ -264,8 +264,7 @@ func (s *historyReplicationDLQSuite) TestWorkflowReplicationTaskFailure() { // Wait for the replication task executor to process all the replication tasks for this workflow. // That way, we will know when the DLQ contains everything it needs for this workflow. - serializer := serialization.NewSerializer() - events := s.waitUntilWorkflowReplicated(ctx, serializer, workflowID) + events := s.waitUntilWorkflowReplicated(ctx, workflowID) // Wait until all the replication tasks for this workflow are in the DLQ. // We need to do this because we don't want to start re-enqueuing the DLQ until it contains all the replication @@ -327,7 +326,7 @@ func (s *historyReplicationDLQSuite) TestWorkflowReplicationTaskFailure() { if s.enableTransitionHistory { s.waitUntilWorkflowVerified(ctx, workflowID, events[len(events)-1].GetEventId()) } else { - s.waitUntilWorkflowReplicated(context.Background(), serializer, workflowID) + s.waitUntilWorkflowReplicated(context.Background(), workflowID) } } @@ -391,7 +390,6 @@ func (s *historyReplicationDLQSuite) waitForNSReplication(ctx context.Context, n // It does this by waiting for the replication task executor to process the workflow completion replication event. 
func (s *historyReplicationDLQSuite) waitUntilWorkflowReplicated( ctx context.Context, - serializer serialization.Serializer, workflowID string, ) []*historypb.HistoryEvent { var historyEvents []*historypb.HistoryEvent @@ -402,7 +400,7 @@ func (s *historyReplicationDLQSuite) waitUntilWorkflowReplicated( if attr.WorkflowId != workflowID { continue } - events, err := serializer.DeserializeEvents(attr.Events) + events, err := serialization.DefaultDecoder.DeserializeEvents(attr.Events) s.NoError(err) historyEvents = append(historyEvents, events...) @@ -419,7 +417,7 @@ func (s *historyReplicationDLQSuite) waitUntilWorkflowReplicated( } completed := false for _, blob := range attr.VersionedTransitionArtifact.EventBatches { - e, err := serializer.DeserializeEvents(blob) + e, err := serialization.DefaultDecoder.DeserializeEvents(blob) s.NoError(err) historyEvents = append(historyEvents, e...) for _, event := range e { diff --git a/tests/xdc/history_replication_signals_and_updates_test.go b/tests/xdc/history_replication_signals_and_updates_test.go index 0bc18c2aaeb..a587383a083 100644 --- a/tests/xdc/history_replication_signals_and_updates_test.go +++ b/tests/xdc/history_replication_signals_and_updates_test.go @@ -683,14 +683,13 @@ func (c *hrsuTestCluster) executeHistoryReplicationTasksUntil( } func (s *hrsuTestSuite) executeHistoryReplicationTask(task *hrsuTestExecutableTask) []*historypb.HistoryEvent { - serializer := serialization.NewSerializer() trackableTask := (*task).TrackableExecutableTask err := trackableTask.Execute() s.NoError(err) task.result <- err attrs := (*task).replicationTask.GetHistoryTaskAttributes() s.NotNil(attrs) - events, err := serializer.DeserializeEvents(attrs.Events) + events, err := serialization.DefaultDecoder.DeserializeEvents(attrs.Events) s.NoError(err) return events } diff --git a/tests/xdc/stream_based_replication_test.go b/tests/xdc/stream_based_replication_test.go index 5735a0a5430..1349d628897 100644 --- 
a/tests/xdc/stream_based_replication_test.go +++ b/tests/xdc/stream_based_replication_test.go @@ -270,17 +270,16 @@ func (s *streamBasedReplicationTestSuite) assertHistoryEvents( AnyTimes() mockClientBean.EXPECT().GetRemoteAdminClient("cluster2").Return(s.clusters[1].AdminClient(), nil).AnyTimes() - serializer := serialization.NewSerializer() cluster1Fetcher := eventhandler.NewHistoryPaginatedFetcher( nil, mockClientBean, - serializer, + s.serializer, s.logger, ) cluster2Fetcher := eventhandler.NewHistoryPaginatedFetcher( nil, mockClientBean, - serializer, + s.serializer, s.logger, ) iterator1 := cluster1Fetcher.GetSingleWorkflowHistoryPaginatedIteratorExclusive( diff --git a/tools/tdbg/commands.go b/tools/tdbg/commands.go index 0338309ac5a..d7e05760be0 100644 --- a/tools/tdbg/commands.go +++ b/tools/tdbg/commands.go @@ -51,7 +51,6 @@ func AdminShowWorkflow(c *cli.Context, clientFactory ClientFactory) error { outputFileName := c.String(FlagOutputFilename) client := clientFactory.AdminClient(c) - serializer := serialization.NewSerializer() ctx, cancel := newContext(c) @@ -151,7 +150,6 @@ func AdminImportWorkflow(c *cli.Context, clientFactory ClientFactory) error { inputFileName := c.String(FlagInputFilename) client := clientFactory.AdminClient(c) - serializer := serialization.NewSerializer() ctx, cancel := newContext(c) diff --git a/tools/tdbg/task_encoder.go b/tools/tdbg/task_encoder.go index 8b0aebf4dff..1bd93bbbeb4 100644 --- a/tools/tdbg/task_encoder.go +++ b/tools/tdbg/task_encoder.go @@ -68,17 +68,17 @@ func NewPredefinedTaskBlobDeserializer() PredefinedTaskBlobDeserializer { func (d PredefinedTaskBlobDeserializer) Deserialize(categoryID int, blob *commonpb.DataBlob) (proto.Message, error) { switch categoryID { case tasks.CategoryIDTransfer: - return serialization.TransferTaskInfoFromBlob(blob) + return serialization.DefaultDecoder.TransferTaskInfoFromBlob(blob) case tasks.CategoryIDTimer: - return serialization.TimerTaskInfoFromBlob(blob) + return 
serialization.DefaultDecoder.TimerTaskInfoFromBlob(blob) case tasks.CategoryIDVisibility: - return serialization.VisibilityTaskInfoFromBlob(blob) + return serialization.DefaultDecoder.VisibilityTaskInfoFromBlob(blob) case tasks.CategoryIDReplication: - return serialization.ReplicationTaskInfoFromBlob(blob) + return serialization.DefaultDecoder.ReplicationTaskInfoFromBlob(blob) case tasks.CategoryIDArchival: - return serialization.ArchivalTaskInfoFromBlob(blob) + return serialization.DefaultDecoder.ArchivalTaskInfoFromBlob(blob) case tasks.CategoryIDOutbound: - return serialization.OutboundTaskInfoFromBlob(blob) + return serialization.DefaultDecoder.OutboundTaskInfoFromBlob(blob) default: return nil, fmt.Errorf("unsupported task category %v", categoryID) } diff --git a/tools/tdbg/task_encoder_test.go b/tools/tdbg/task_encoder_test.go index 5fdce3fef10..9424709552e 100644 --- a/tools/tdbg/task_encoder_test.go +++ b/tools/tdbg/task_encoder_test.go @@ -95,11 +95,11 @@ func TestPredefinedTasks(t *testing.T) { &tasks.ArchiveExecutionTask{}, &tasks.StateMachineOutboundTask{}, } - serializer := serialization.NewTaskSerializer() + taskSerializer := serialization.NewTaskSerializer(serialization.NewSerializer()) expectedTaskTypes := make([]string, len(historyTasks)) for i, task := range historyTasks { expectedTaskTypes[i] = enumsspb.TaskType_name[int32(task.GetType())] - blob, err := serializer.SerializeTask(task) + blob, err := taskSerializer.SerializeTask(task) require.NoError(t, err) err = encoder.Encode(&buf, task.GetCategory().ID(), blob) require.NoError(t, err) diff --git a/tools/tdbg/tdbgtest/output_parsing_test.go b/tools/tdbg/tdbgtest/output_parsing_test.go index 772dc8f9308..c7dce1180e3 100644 --- a/tools/tdbg/tdbgtest/output_parsing_test.go +++ b/tools/tdbg/tdbgtest/output_parsing_test.go @@ -39,7 +39,8 @@ func TestParseDLQMessages(t *testing.T) { }, TaskID: 13, } - blob, err := serialization.NewTaskSerializer().SerializeTask(task) + taskSerializer := 
serialization.NewTaskSerializer(serialization.NewSerializer()) + blob, err := taskSerializer.SerializeTask(task) require.NoError(t, err) client := &testClient{ getDLQTasksFn: func(request *adminservice.GetDLQTasksRequest) (*adminservice.GetDLQTasksResponse, error) {