Skip to content

Commit 5b34fb2

Browse files
fix: fix a race condition in the tracker
This commit fixes a race condition in the push tracker where both the p.done and p.err channels could have a value at the same time. This could cause the distributor to return success when it should instead have returned an error.
1 parent 849a921 commit 5b34fb2

File tree

3 files changed

+253
-41
lines changed

3 files changed

+253
-41
lines changed

pkg/distributor/distributor.go

Lines changed: 12 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -454,29 +454,6 @@ type streamTracker struct {
454454
failed atomic.Int32
455455
}
456456

457-
// TODO taken from Cortex, see if we can refactor out an usable interface.
458-
type pushTracker struct {
459-
streamsPending atomic.Int32
460-
streamsFailed atomic.Int32
461-
done chan struct{}
462-
err chan error
463-
}
464-
465-
// doneWithResult records the result of a stream push.
466-
// If err is nil, the stream push is considered successful.
467-
// If err is not nil, the stream push is considered failed.
468-
func (p *pushTracker) doneWithResult(err error) {
469-
if err == nil {
470-
if p.streamsPending.Dec() == 0 {
471-
p.done <- struct{}{}
472-
}
473-
} else {
474-
if p.streamsFailed.Inc() == 1 {
475-
p.err <- err
476-
}
477-
}
478-
}
479-
480457
func (d *Distributor) waitSimulatedLatency(ctx context.Context, tenantID string, start time.Time) {
481458
latency := d.validator.SimulatedPushLatency(tenantID)
482459
if latency > 0 {
@@ -754,10 +731,7 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe
754731
const maxExpectedReplicationSet = 5 // typical replication factor 3 plus one for inactive plus one for luck
755732
var descs [maxExpectedReplicationSet]ring.InstanceDesc
756733

757-
tracker := pushTracker{
758-
done: make(chan struct{}, 1), // buffer avoids blocking if caller terminates - sendSamples() only sends once on each
759-
err: make(chan error, 1),
760-
}
734+
tracker := newBasicPushTracker()
761735
streamsToWrite := 0
762736
if d.cfg.IngesterEnabled {
763737
streamsToWrite += len(streams)
@@ -766,15 +740,15 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe
766740
streamsToWrite += len(streams)
767741
}
768742
// We must correctly set streamsPending before beginning any writes to ensure we don't have a race between finishing all of one path before starting the other.
769-
tracker.streamsPending.Store(int32(streamsToWrite))
743+
tracker.Add(int32(streamsToWrite))
770744

771745
if d.cfg.KafkaEnabled {
772746
subring, err := d.partitionRing.PartitionRing().ShuffleShard(tenantID, d.validator.IngestionPartitionsTenantShardSize(tenantID))
773747
if err != nil {
774748
return nil, err
775749
}
776750
// We don't need to create a new context like the ingester writes, because we don't return unless all writes have succeeded.
777-
d.sendStreamsToKafka(ctx, streams, tenantID, &tracker, subring)
751+
d.sendStreamsToKafka(ctx, streams, tenantID, tracker, subring)
778752
}
779753

780754
if d.cfg.IngesterEnabled {
@@ -823,7 +797,7 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe
823797
case d.ingesterTasks <- pushIngesterTask{
824798
ingester: ingester,
825799
streamTracker: samples,
826-
pushTracker: &tracker,
800+
pushTracker: tracker,
827801
ctx: localCtx,
828802
cancel: cancel,
829803
}:
@@ -833,14 +807,11 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe
833807
}
834808
}
835809

836-
select {
837-
case err := <-tracker.err:
810+
if err := tracker.Wait(ctx); err != nil {
838811
return nil, err
839-
case <-tracker.done:
840-
return &logproto.PushResponse{}, validationErr
841-
case <-ctx.Done():
842-
return nil, ctx.Err()
843812
}
813+
814+
return &logproto.PushResponse{}, nil
844815
}
845816

846817
// missingEnforcedLabels returns true if the stream is missing any of the required labels.
@@ -1135,7 +1106,7 @@ func (d *Distributor) truncateLines(vContext validationContext, stream *logproto
11351106

11361107
type pushIngesterTask struct {
11371108
streamTracker []*streamTracker
1138-
pushTracker *pushTracker
1109+
pushTracker PushTracker
11391110
ingester ring.InstanceDesc
11401111
ctx context.Context
11411112
cancel context.CancelFunc
@@ -1172,12 +1143,12 @@ func (d *Distributor) sendStreams(task pushIngesterTask) {
11721143
if task.streamTracker[i].failed.Inc() <= int32(task.streamTracker[i].maxFailures) {
11731144
continue
11741145
}
1175-
task.pushTracker.doneWithResult(err)
1146+
task.pushTracker.Done(err)
11761147
} else {
11771148
if task.streamTracker[i].succeeded.Inc() != int32(task.streamTracker[i].minSuccess) {
11781149
continue
11791150
}
1180-
task.pushTracker.doneWithResult(nil)
1151+
task.pushTracker.Done(nil)
11811152
}
11821153
}
11831154
}
@@ -1209,14 +1180,14 @@ func (d *Distributor) sendStreamsErr(ctx context.Context, ingester ring.Instance
12091180
return err
12101181
}
12111182

1212-
func (d *Distributor) sendStreamsToKafka(ctx context.Context, streams []KeyedStream, tenant string, tracker *pushTracker, subring *ring.PartitionRing) {
1183+
func (d *Distributor) sendStreamsToKafka(ctx context.Context, streams []KeyedStream, tenant string, tracker PushTracker, subring *ring.PartitionRing) {
12131184
for _, s := range streams {
12141185
go func(s KeyedStream) {
12151186
err := d.sendStreamToKafka(ctx, s, tenant, subring)
12161187
if err != nil {
12171188
err = fmt.Errorf("failed to write stream to kafka: %w", err)
12181189
}
1219-
tracker.doneWithResult(err)
1190+
tracker.Done(err)
12201191
}(s)
12211192
}
12221193
}

pkg/distributor/tracker.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package distributor
2+
3+
import (
4+
"context"
5+
"sync"
6+
)
7+
8+
// PushTracker is an interface to track the status of pushes and wait on
9+
// their completion.
10+
type PushTracker interface {
11+
// Add increments the number of pushes. It must not be called after the
12+
// last call to [Done] has completed (otherwise it can panic).
13+
Add(int32)
14+
15+
// Done decrements the number of pushes. It accepts an optional error
16+
// if the push failed; pass nil for a successful push.
17+
Done(err error)
18+
19+
// Wait blocks until all pushes are done or a push fails, whichever happens first.
// The basicPushTracker implementation in this file also returns ctx.Err() if ctx is done first.
20+
Wait(ctx context.Context) error
21+
}
22+
23+
// basicPushTracker is a mutex-guarded counter of pending pushes. It remembers
// the first error reported through Done and signals "all done" and "first
// error" by closing doneCh and errCh respectively.
//
// It must be created with newBasicPushTracker (the zero value has nil
// channels) and must not be copied, as it contains a sync.Mutex.
type basicPushTracker struct {
	mtx      sync.Mutex    // protects the fields below.
	n        int32         // the number of pushes that have been added but are not done yet.
	firstErr error         // the first reported error from a push.
	doneCh   chan struct{} // closed when all pushes are done.
	errCh    chan struct{} // closed when the first error is reported.
	done     bool          // fast path, equivalent to select { case <-t.doneCh: default: }
}

// newBasicPushTracker returns a new, initialized [basicPushTracker].
func newBasicPushTracker() *basicPushTracker {
	return &basicPushTracker{
		doneCh: make(chan struct{}),
		errCh:  make(chan struct{}),
	}
}

// Add implements the [PushTracker] interface.
//
// It adds n to the number of pending pushes. n may be negative as long as the
// counter does not drop below zero. Add panics with "Add called after last
// call to Done" once the counter has been brought to zero by Done, and with
// "Negative counter" if the resulting counter would be negative (which
// includes int32 wrap-around). When Add panics the counter is left unchanged.
//
// NOTE: Add never closes doneCh, so lowering the counter to zero with a
// negative n does not wake goroutines that are already blocked in Wait.
func (t *basicPushTracker) Add(n int32) {
	t.mtx.Lock()
	defer t.mtx.Unlock()
	if t.done {
		panic("Add called after last call to Done")
	}
	// Validate the new value before committing it, so that a caller that
	// recovers from the panic is not left with a corrupted (negative) counter.
	next := t.n + n
	if next < 0 {
		panic("Negative counter")
	}
	t.n = next
}

// Done implements the [PushTracker] interface.
//
// It records the outcome of one push: a non-nil err is remembered if it is
// the first one reported, and the pending counter is decremented. Done panics
// with "Done called more times than Add" if no push is pending.
func (t *basicPushTracker) Done(err error) {
	t.mtx.Lock()
	defer t.mtx.Unlock()
	if t.n <= 0 {
		// We panic just like [sync.WaitGroup].
		panic("Done called more times than Add")
	}
	if err != nil && t.firstErr == nil {
		// errCh can never be closed twice: firstErr is only assigned a
		// non-nil value here, so this branch is taken at most once.
		t.firstErr = err
		close(t.errCh)
	}
	t.n--
	if t.n == 0 {
		// doneCh can never be closed twice either: once done is set, Add
		// panics, so the counter can never become positive again.
		close(t.doneCh)
		t.done = true
	}
}

// Wait implements the [PushTracker] interface.
//
// It returns immediately with the first recorded error (possibly nil) if an
// error has already been reported or no push is pending. Otherwise it blocks
// until all pushes are done, a push fails, or ctx is done, and returns the
// first recorded error or ctx.Err() respectively.
func (t *basicPushTracker) Wait(ctx context.Context) error {
	t.mtx.Lock()
	// The mutex is needed here: n can still change while doneCh is open, and
	// firstErr can still change while neither doneCh nor errCh is closed.
	if t.firstErr != nil || t.n == 0 {
		// Copy firstErr before releasing the mutex for the same reason.
		res := t.firstErr
		t.mtx.Unlock()
		return res
	}
	t.mtx.Unlock()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-t.doneCh:
		// Must return firstErr as doneCh is also closed if the last push
		// failed. No mutex is needed: firstErr is never modified after
		// doneCh is closed.
		return t.firstErr
	case <-t.errCh:
		// No mutex is needed here either: firstErr is never modified after
		// errCh is closed.
		return t.firstErr
	}
}

pkg/distributor/tracker_test.go

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package distributor
2+
3+
import (
4+
"context"
5+
"errors"
6+
"math/rand"
7+
"sync"
8+
"testing"
9+
"time"
10+
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestBasicPushTracker(t *testing.T) {
15+
t.Run("a new tracker that has never been incremented should never block", func(t *testing.T) {
16+
tracker := newBasicPushTracker()
17+
ctx, cancel := context.WithTimeout(t.Context(), time.Second)
18+
t.Cleanup(cancel)
19+
require.NoError(t, tracker.Wait(ctx))
20+
})
21+
22+
t.Run("a canceled context should return a context canceled error", func(t *testing.T) {
23+
tracker := newBasicPushTracker()
24+
tracker.Add(1)
25+
ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond)
26+
t.Cleanup(cancel)
27+
require.EqualError(t, tracker.Wait(ctx), "context deadline exceeded")
28+
})
29+
30+
t.Run("a done tracker with no errors should return nil", func(t *testing.T) {
31+
tracker := newBasicPushTracker()
32+
tracker.Add(1)
33+
tracker.Done(nil)
34+
ctx, cancel := context.WithTimeout(t.Context(), time.Second)
35+
t.Cleanup(cancel)
36+
require.NoError(t, tracker.Wait(ctx))
37+
})
38+
39+
t.Run("a done tracker with an error should return the error", func(t *testing.T) {
40+
tracker := newBasicPushTracker()
41+
tracker.Add(1)
42+
tracker.Done(errors.New("an error occurred"))
43+
ctx, cancel := context.WithTimeout(t.Context(), time.Second)
44+
t.Cleanup(cancel)
45+
require.EqualError(t, tracker.Wait(ctx), "an error occurred")
46+
})
47+
48+
t.Run("a done tracker should return the first error that occurred", func(t *testing.T) {
49+
tracker := newBasicPushTracker()
50+
tracker.Add(2)
51+
tracker.Done(errors.New("an error occurred"))
52+
tracker.Done(errors.New("another error occurred"))
53+
ctx, cancel := context.WithTimeout(t.Context(), time.Second)
54+
t.Cleanup(cancel)
55+
require.EqualError(t, tracker.Wait(ctx), "an error occurred")
56+
})
57+
58+
t.Run("a done tracker should return at least one error", func(t *testing.T) {
59+
t1 := newBasicPushTracker()
60+
t1.Add(2)
61+
t1.Done(nil)
62+
t1.Done(errors.New("an error occurred"))
63+
ctx, cancel := context.WithTimeout(t.Context(), time.Second)
64+
t.Cleanup(cancel)
65+
require.EqualError(t, t1.Wait(ctx), "an error occurred")
66+
// And now test the opposite sequence.
67+
t2 := newBasicPushTracker()
68+
t2.Add(2)
69+
t2.Done(errors.New("an error occurred"))
70+
t2.Done(nil)
71+
ctx, cancel = context.WithTimeout(t.Context(), time.Second)
72+
t.Cleanup(cancel)
73+
require.EqualError(t, t2.Wait(ctx), "an error occurred")
74+
})
75+
76+
t.Run("more Done than Add should panic", func(t *testing.T) {
77+
// Should panic if Done is called before Add.
78+
require.PanicsWithValue(t, "Done called more times than Add", func() {
79+
tracker := newBasicPushTracker()
80+
tracker.Done(nil)
81+
})
82+
// Should panic if Done is called more times than Add.
83+
require.PanicsWithValue(t, "Done called more times than Add", func() {
84+
tracker := newBasicPushTracker()
85+
tracker.Add(1)
86+
tracker.Done(nil)
87+
tracker.Done(nil)
88+
})
89+
})
90+
91+
t.Run("Add after Done should panic", func(t *testing.T) {
92+
require.PanicsWithValue(t, "Add called after last call to Done", func() {
93+
tracker := newBasicPushTracker()
94+
tracker.Add(1)
95+
tracker.Done(nil)
96+
tracker.Add(1)
97+
})
98+
})
99+
100+
t.Run("Negative counter should panic", func(t *testing.T) {
101+
require.PanicsWithValue(t, "Negative counter", func() {
102+
tracker := newBasicPushTracker()
103+
tracker.Add(-1)
104+
})
105+
})
106+
}
107+
108+
// Run with go test -fuzz=FuzzBasicPushTracker.
109+
func FuzzBasicPushTracker(f *testing.F) {
110+
rand.Seed(time.Now().UnixNano())
111+
f.Add(uint16(100))
112+
f.Fuzz(func(t *testing.T, n uint16) {
113+
wg := sync.WaitGroup{}
114+
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
115+
t.Cleanup(cancel)
116+
tracker := newBasicPushTracker()
117+
tracker.Add(int32(n))
118+
// Create a random number of waiters.
119+
for i := 0; i < rand.Intn(100); i++ {
120+
wg.Add(1)
121+
go func() {
122+
defer wg.Done()
123+
// Sleep a random time up to 100ms.
124+
time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
125+
require.NoError(t, tracker.Wait(ctx))
126+
}()
127+
}
128+
// Done should be called for each n, cannot be random.
129+
for i := 0; i < int(n); i++ {
130+
wg.Add(1)
131+
go func() {
132+
defer wg.Done()
133+
// Sleep a random time up to 100ms too.
134+
time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
135+
tracker.Done(nil)
136+
}()
137+
}
138+
wg.Wait()
139+
})
140+
}

0 commit comments

Comments
 (0)