From 665ad9268efa7ebc447abcd799870f51a3690a97 Mon Sep 17 00:00:00 2001 From: Aleksandr Razumov Date: Sat, 15 Apr 2023 17:39:47 +0300 Subject: [PATCH] fix: structured concurrency --- .golangci.yml | 2 +- telegram/updates/config.go | 3 +++ telegram/updates/internal/e2e/manager_test.go | 2 +- telegram/updates/manager.go | 6 ----- telegram/updates/sequence_box.go | 6 +++++ telegram/updates/state.go | 24 ++++++------------- telegram/updates/state_apply.go | 12 ++++++++++ telegram/updates/state_channel.go | 21 +++++++--------- 8 files changed, 39 insertions(+), 37 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 7c646a5b00..8fbca15310 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -50,7 +50,7 @@ linters: enable: - depguard - dogsled - - dupl +# - dupl - errcheck - gochecknoinits - goconst diff --git a/telegram/updates/config.go b/telegram/updates/config.go index 9df74a8bc6..820ee76e81 100644 --- a/telegram/updates/config.go +++ b/telegram/updates/config.go @@ -50,6 +50,9 @@ func (cfg *Config) setDefaults() { if cfg.TracerProvider == nil { cfg.TracerProvider = trace.NewNoopTracerProvider() } + if cfg.Storage == nil { + cfg.Storage = newMemStorage() + } if cfg.OnChannelTooLong == nil { cfg.OnChannelTooLong = func(channelID int64) { cfg.Logger.Error("Difference too long", zap.Int64("channel_id", channelID)) diff --git a/telegram/updates/internal/e2e/manager_test.go b/telegram/updates/internal/e2e/manager_test.go index 4062d46bcb..afe63a4968 100644 --- a/telegram/updates/internal/e2e/manager_test.go +++ b/telegram/updates/internal/e2e/manager_test.go @@ -181,7 +181,7 @@ func testManager(t *testing.T, f func(s *server, storage updates.StateStorage) c }) t.Log("Waiting for shutdown") - require.NoError(t, g.Wait()) + require.ErrorIs(t, g.Wait(), context.Canceled) t.Log("Checking") require.Equal(t, s.messages, h.messages) diff --git a/telegram/updates/manager.go b/telegram/updates/manager.go index 824f4ff885..1db8d63e4d 100644 --- a/telegram/updates/manager.go +++ b/telegram/updates/manager.go @@ -157,12 +157,6 @@ func (m *Manager) Run(ctx context.Context, api API, userID int64, opt AuthOption wg.Go(func() error { return m.state.Run(ctx) }) - wg.Go(func() error { - <-ctx.Done() - lg.Debug("Stopping") - m.state.Stop() - return nil - }) lg.Debug("Wait") return wg.Wait() } diff --git a/telegram/updates/sequence_box.go b/telegram/updates/sequence_box.go index 69830e724f..1e0caecf82 100644 --- a/telegram/updates/sequence_box.go +++ b/telegram/updates/sequence_box.go @@ -52,6 +52,9 @@ func newSequenceBox(cfg sequenceConfig) *sequenceBox { } func (s *sequenceBox) Handle(ctx context.Context, u update) error { + ctx, span := s.tracer.Start(ctx, "sequenceBox.Handle") + defer span.End() + log := s.log.With(zap.Int("upd_from", u.start()), zap.Int("upd_to", u.end())) if checkGap(s.state, u.State, u.Count) == gapIgnore { log.Debug("Outdated update, skipping", zap.Int("internalState", s.state)) @@ -110,6 +113,9 @@ func (s *sequenceBox) Handle(ctx context.Context, u update) error { } func (s *sequenceBox) applyPending(ctx context.Context) error { + ctx, span := s.tracer.Start(ctx, "sequenceBox.applyPending") + defer span.End() + sort.SliceStable(s.pending, func(i, j int) bool { return s.pending[i].start() < s.pending[j].start() }) diff --git a/telegram/updates/state.go b/telegram/updates/state.go index 55b993c9b1..2332c6391e 100644 --- a/telegram/updates/state.go +++ b/telegram/updates/state.go @@ -117,8 +117,7 @@ func newState(ctx context.Context, cfg stateConfig) *internalState { state := s.newChannelState(id, info.AccessHash, info.Pts) s.channels[id] = state s.wg.Go(func() error { - state.Run(ctx) - return nil + return state.Run(ctx) }) } @@ -138,13 +137,6 @@ func (s *internalState) Push(ctx context.Context, u tg.UpdatesClass) error { } } -func (s *internalState) Stop() { - close(s.externalQueue) - for _, c := range s.channels { - c.Stop() - } -} - func (s *internalState) Run(ctx context.Context) error { s.log.Debug("Starting updates handler") defer s.log.Debug("Updates handler stopped") @@ -152,13 +144,12 @@ func (s *internalState) Run(ctx context.Context) error { for { select { - case u, ok := <-s.externalQueue: - if !ok { - if len(s.pts.pending) > 0 || len(s.qts.pending) > 0 || len(s.seq.pending) > 0 { - s.getDifferenceLogger(ctx) - } - return nil + case <-ctx.Done(): + if len(s.pts.pending) > 0 || len(s.qts.pending) > 0 || len(s.seq.pending) > 0 { + s.getDifferenceLogger(ctx) } + return ctx.Err() + case u := <-s.externalQueue: ctx := trace.ContextWithSpanContext(ctx, u.span) if err := s.handleUpdates(ctx, u.update); err != nil { s.log.Error("Handle updates error", zap.Error(err)) @@ -328,8 +319,7 @@ func (s *internalState) handleChannel(ctx context.Context, channelID int64, date state = s.newChannelState(channelID, accessHash, localPts) s.channels[channelID] = state s.wg.Go(func() error { - state.Run(ctx) - return nil + return state.Run(ctx) }) } diff --git a/telegram/updates/state_apply.go b/telegram/updates/state_apply.go index 77e993cee2..4bf0875064 100644 --- a/telegram/updates/state_apply.go +++ b/telegram/updates/state_apply.go @@ -3,6 +3,7 @@ package updates import ( "context" + "go.opentelemetry.io/otel/trace" "go.uber.org/zap" "github.com/gotd/td/tg" @@ -33,6 +34,9 @@ func (s *internalState) applySeq(ctx context.Context, state int, updates []updat } func (s *internalState) applyCombined(ctx context.Context, comb *tg.UpdatesCombined) (ptsChanged bool, err error) { + ctx, span := s.tracer.Start(ctx, "internalState.applyCombined") + defer span.End() + var ( ents = entities{ Users: comb.Users, @@ -56,6 +60,7 @@ func (s *internalState) applyCombined(ctx context.Context, comb *tg.UpdatesCombi if err := st.Push(ctx, channelUpdate{ update: u, entities: ents, + span: trace.SpanContextFromContext(ctx), }); err != nil { s.log.Error("Push channel update error", zap.Error(err)) } @@ -78,6 +83,7 @@ func (s *internalState) applyCombined(ctx context.Context, comb *tg.UpdatesCombi if err := s.handleChannel(ctx, channelID, comb.Date, pts, ptsCount, channelUpdate{ update: u, entities: ents, + span: trace.SpanContextFromContext(ctx), }); err != nil { s.log.Error("Handle channel update error", zap.Error(err)) } @@ -131,6 +137,9 @@ func (s *internalState) applyCombined(ctx context.Context, comb *tg.UpdatesCombi } func (s *internalState) applyPts(ctx context.Context, state int, updates []update) error { + ctx, span := s.tracer.Start(ctx, "internalState.applyPts") + defer span.End() + var ( converted []tg.UpdateClass ents entities @@ -157,6 +166,9 @@ func (s *internalState) applyPts(ctx context.Context, state int, updates []updat } func (s *internalState) applyQts(ctx context.Context, state int, updates []update) error { + ctx, span := s.tracer.Start(ctx, "internalState.applyQts") + defer span.End() + var ( converted []tg.UpdateClass ents entities diff --git a/telegram/updates/state_channel.go b/telegram/updates/state_channel.go index 7e505fdf93..ab9815ebe1 100644 --- a/telegram/updates/state_channel.go +++ b/telegram/updates/state_channel.go @@ -95,11 +95,7 @@ func (s *channelState) Push(ctx context.Context, u channelUpdate) error { } } -func (s *channelState) Stop() { - close(s.updates) -} - -func (s *channelState) Run(ctx context.Context) { +func (s *channelState) Run(ctx context.Context) error { // Subscribe to channel updates. if err := s.getDifference(ctx); err != nil { s.log.Error("Failed to subscribe to channel updates", zap.Error(err)) @@ -107,19 +103,20 @@ func (s *channelState) Run(ctx context.Context) { for { select { - case u, ok := <-s.updates: - if !ok { - if len(s.pts.pending) > 0 { - s.getDifferenceLogger(ctx) - } - return - } + case u := <-s.updates: + ctx := trace.ContextWithSpanContext(ctx, u.span) if err := s.handleUpdate(ctx, u.update, u.entities); err != nil { s.log.Error("Handle update error", zap.Error(err)) } case <-s.pts.gapTimeout.C: s.log.Debug("Gap timeout") s.getDifferenceLogger(ctx) + case <-ctx.Done(): + if len(s.pts.pending) > 0 { + // This will probably fail. + s.getDifferenceLogger(ctx) + } + return ctx.Err() case <-s.idleTimeout.C: s.log.Debug("Idle timeout") s.resetIdleTimer()