From 8803efbcefd84c50c74d16cf14a8fce68f674e7e Mon Sep 17 00:00:00 2001 From: Zohaib Date: Sat, 19 Aug 2023 09:22:05 -0700 Subject: [PATCH] Fixing datetime replication Turns out CBOR by default serializes date time into unix timestamp and drops the serialization information about it. This PR adds new struct and enforces tag to be serialized and deserialized along with the values. --- core/events.go | 10 +++++ db/change_log.go | 2 +- db/change_log_event.go | 80 +++++++++++++++++++++++++++++++--- logstream/replication_event.go | 43 +++++++++++++++--- logstream/replicator.go | 6 +-- marmot.go | 4 +- 6 files changed, 128 insertions(+), 17 deletions(-) create mode 100644 core/events.go diff --git a/core/events.go b/core/events.go new file mode 100644 index 0000000..baf8d26 --- /dev/null +++ b/core/events.go @@ -0,0 +1,10 @@ +package core + +import "github.com/fxamacker/cbor/v2" + +var CBORTags = cbor.NewTagSet() + +type ReplicableEvent[T any] interface { + Wrap() (T, error) + Unwrap() (T, error) +} diff --git a/db/change_log.go b/db/change_log.go index 786f5f0..0815ecb 100644 --- a/db/change_log.go +++ b/db/change_log.go @@ -368,7 +368,7 @@ func (conn *SqliteStreamDB) publishChangeLog() { err = conn.consumeChangeLogs(change.TableName, []*changeLogEntry{&logEntry}) if err != nil { - if err == ErrLogNotReadyToPublish || err == context.Canceled { + if errors.Is(err, ErrLogNotReadyToPublish) || errors.Is(err, context.Canceled) { break } diff --git a/db/change_log_event.go b/db/change_log_event.go index d594ca1..35c23a4 100644 --- a/db/change_log_event.go +++ b/db/change_log_event.go @@ -2,15 +2,23 @@ package db import ( "hash/fnv" + "reflect" "sort" "sync" + "time" "github.com/fxamacker/cbor/v2" + "github.com/maxpert/marmot/core" + "github.com/rs/zerolog/log" ) var tablePKColumnsCache = make(map[string][]string) var tablePKColumnsLock = sync.RWMutex{} +type sensitiveTypeWrapper struct { + Time *time.Time `cbor:"1,keyasint,omitempty"` +} + type ChangeLogEvent struct { Id int64 Type string @@ -19,15 +27,50 @@ type ChangeLogEvent struct { tableInfo []*ColumnInfo `cbor:"-"` } -func (e *ChangeLogEvent) Marshal() ([]byte, error) { - return cbor.Marshal(e) +func init() { + err := core.CBORTags.Add( + cbor.TagOptions{ + DecTag: cbor.DecTagRequired, + EncTag: cbor.EncTagRequired, + }, + reflect.TypeOf(sensitiveTypeWrapper{}), + 32, + ) + + log.Panic().Err(err) +} + +func (s sensitiveTypeWrapper) GetValue() any { + // Right now only sensitive value is Time + return s.Time +} + +func (e ChangeLogEvent) Wrap() (ChangeLogEvent, error) { + return e.prepare(), nil } -func (e *ChangeLogEvent) Unmarshal(data []byte) error { - return cbor.Unmarshal(data, e) +func (e ChangeLogEvent) Unwrap() (ChangeLogEvent, error) { + ret := ChangeLogEvent{ + Id: e.Id, + TableName: e.TableName, + Type: e.Type, + Row: map[string]any{}, + tableInfo: e.tableInfo, + } + + for k, v := range e.Row { + if st, ok := v.(sensitiveTypeWrapper); ok { + ret.Row[k] = st.GetValue() + continue + } + + ret.Row[k] = v + } + + return ret, nil } -func (e *ChangeLogEvent) Hash() (uint64, error) { +func (e ChangeLogEvent) Hash() (uint64, error) { hasher := fnv.New64() enc := cbor.NewEncoder(hasher) err := enc.StartIndefiniteArray() @@ -56,7 +99,7 @@ func (e *ChangeLogEvent) Hash() (uint64, error) { return hasher.Sum64(), nil } -func (e *ChangeLogEvent) getSortedPKColumns() []string { +func (e ChangeLogEvent) getSortedPKColumns() []string { tablePKColumnsLock.RLock() if values, found := tablePKColumnsCache[e.TableName]; found { @@ -79,3 +122,28 @@ func (e *ChangeLogEvent) getSortedPKColumns() []string { tablePKColumnsCache[e.TableName] = pkColumns return pkColumns } + +func (e ChangeLogEvent) prepare() ChangeLogEvent { + needsTransform := false + preparedRow := map[string]any{} + for k, v := range e.Row { + if t, ok := v.(time.Time); ok { + preparedRow[k] = sensitiveTypeWrapper{Time: &t} + needsTransform = true + } else { + preparedRow[k] = v + } + } + + if !needsTransform { + return e + } + + return ChangeLogEvent{ + Id: e.Id, + Type: e.Type, + TableName: e.TableName, + Row: preparedRow, + tableInfo: e.tableInfo, + } +} diff --git a/logstream/replication_event.go b/logstream/replication_event.go index ded01f2..e7d1c9f 100644 --- a/logstream/replication_event.go +++ b/logstream/replication_event.go @@ -1,16 +1,49 @@ package logstream -import "github.com/fxamacker/cbor/v2" +import ( + "github.com/fxamacker/cbor/v2" + "github.com/maxpert/marmot/core" +) -type ReplicationEvent[T any] struct { +type ReplicationEvent[T core.ReplicableEvent[T]] struct { FromNodeId uint64 - Payload *T + Payload T } func (e *ReplicationEvent[T]) Marshal() ([]byte, error) { - return cbor.Marshal(e) + wrappedPayload, err := e.Payload.Wrap() + if err != nil { + return nil, err + } + + ev := ReplicationEvent[T]{ + FromNodeId: e.FromNodeId, + Payload: wrappedPayload, + } + + em, err := cbor.EncOptions{}.EncModeWithTags(core.CBORTags) + if err != nil { + return nil, err + } + + return em.Marshal(ev) } func (e *ReplicationEvent[T]) Unmarshal(data []byte) error { - return cbor.Unmarshal(data, e) + dm, err := cbor.DecOptions{}.DecModeWithTags(core.CBORTags) + if err != nil { + return nil + } + + err = dm.Unmarshal(data, e) + if err != nil { + return err + } + + e.Payload, err = e.Payload.Unwrap() + if err != nil { + return err + } + + return nil } diff --git a/logstream/replicator.go b/logstream/replicator.go index 2ca2329..e6faa4d 100644 --- a/logstream/replicator.go +++ b/logstream/replicator.go @@ -2,6 +2,7 @@ package logstream import ( "context" + "errors" "fmt" "time" @@ -178,8 +179,7 @@ func (r *Replicator) Listen(shardID uint64, callback func(payload []byte) error) savedSeq := r.repState.get(streamName(shardID, r.compressionEnabled)) for sub.IsValid() { msg, err := sub.NextMsg(5 * time.Second) - - if err == nats.ErrTimeout { + if errors.Is(err, nats.ErrTimeout) { continue } @@ -199,7 +199,7 @@ func (r *Replicator) Listen(shardID uint64, callback func(payload []byte) error) err = r.invokeListener(callback, msg) if err != nil { msg.Nak() - if err == context.Canceled { + if errors.Is(err, context.Canceled) { return nil } diff --git a/marmot.go b/marmot.go index 918ddef..f39a2e7 100644 --- a/marmot.go +++ b/marmot.go @@ -189,7 +189,7 @@ func onChangeEvent(streamDB *db.SqliteStreamDB, ctxSt *utils.StateContext, event return err } - return streamDB.Replicate(ev.Payload) + return streamDB.Replicate(&ev.Payload) } } @@ -206,7 +206,7 @@ func onTableChanged(r *logstream.Replicator, ctxSt *utils.StateContext, events E ev := &logstream.ReplicationEvent[db.ChangeLogEvent]{ FromNodeId: nodeID, - Payload: event, + Payload: *event, } data, err := ev.Marshal()