diff --git a/pubsub/batcher/batcher.go b/pubsub/batcher/batcher.go
index 917cef822a..bbcc3de266 100644
--- a/pubsub/batcher/batcher.go
+++ b/pubsub/batcher/batcher.go
@@ -23,6 +23,8 @@ import (
 	"errors"
 	"reflect"
 	"sync"
+	"sync/atomic"
+	"time"
 )
 
 // Split determines how to split n (representing n items) into batches based on
@@ -72,6 +74,9 @@ type Batcher struct {
 	pending   []waiter // items waiting to be handled
 	nHandlers int      // number of currently running handler goroutines
 	shutdown  bool
+
+	batchSizeTimeout    time.Time // time at which len(pending) first fell below opts.MinBatchSize, or the zero time
+	batchTimeoutRunning int32     // atomic flag; 1 while a BatchTimeout goroutine is waiting
 }
 
 // Message is larger than the maximum batch byte size
@@ -96,6 +101,9 @@ type Options struct {
 	MaxBatchSize int
 	// Maximum bytesize of a batch. 0 means no limit.
 	MaxBatchByteSize int
+	// BatchTimeout is the maximum time a batch can wait below MinBatchSize
+	// before being sent anyway. 0 means no timeout.
+	BatchTimeout time.Duration
 }
 
 // newOptionsWithDefaults returns Options with defaults applied to opts.
@@ -201,12 +209,31 @@ func (b *Batcher) AddNoWait(item interface{}) <-chan error {
 		// If we can start a handler, do so with the item just added and any others that are pending.
 		batch := b.nextBatch()
 		if batch != nil {
-			b.wg.Add(1)
-			go func() {
-				b.callHandler(batch)
-				b.wg.Done()
-			}()
-			b.nHandlers++
+			b.handleBatch(batch)
+		}
+
+		if batch == nil && len(b.pending) > 0 && b.opts.BatchTimeout > 0 {
+			// Ensure the pending items are sent once the timeout expires. The CAS
+			// guarantees that at most one timeout goroutine runs at a time, so we
+			// don't duplicate work.
+			if atomic.CompareAndSwapInt32(&b.batchTimeoutRunning, 0, 1) {
+				// If batchSizeTimeout is the zero time, this is the first item added
+				// below the minimum batch size. Record when that happened so that
+				// nextBatch() can flush the batch once the timeout expires.
+				if b.batchSizeTimeout.IsZero() {
+					b.batchSizeTimeout = time.Now()
+				}
+
+				go func() {
+					<-time.After(b.opts.BatchTimeout)
+					b.mu.Lock()
+					defer b.mu.Unlock()
+					atomic.StoreInt32(&b.batchTimeoutRunning, 0)
+					if batch := b.nextBatch(); batch != nil {
+						b.handleBatch(batch)
+					}
+				}()
+			}
 		}
 	}
 	// If we can't start a handler, then one of the currently running handlers will
@@ -214,14 +241,33 @@ func (b *Batcher) AddNoWait(item interface{}) <-chan error {
 	return c
 }
 
+// handleBatch starts a handler goroutine for batch. b.mu must be held.
+func (b *Batcher) handleBatch(batch []waiter) {
+	if len(batch) == 0 {
+		return
+	}
+
+	b.wg.Add(1)
+	go func() {
+		b.callHandler(batch)
+		b.wg.Done()
+	}()
+	b.nHandlers++
+}
+
 // nextBatch returns the batch to process, and updates b.pending.
 // It returns nil if there's no batch ready for processing.
 // b.mu must be held.
 func (b *Batcher) nextBatch() []waiter {
-	if len(b.pending) < b.opts.MinBatchSize {
+	if len(b.pending) < b.opts.MinBatchSize && b.respectMinBatchSize() {
 		return nil
 	}
 
+	if len(b.pending) < b.opts.MinBatchSize {
+		// We're sending a batch below the minimum size, so reset the timeout clock.
+		b.batchSizeTimeout = time.Time{}
+	}
+
 	if b.opts.MaxBatchByteSize == 0 && (b.opts.MaxBatchSize == 0 || len(b.pending) <= b.opts.MaxBatchSize) {
 		// Send it all!
 		batch := b.pending
@@ -250,6 +296,26 @@ func (b *Batcher) nextBatch() []waiter {
 	return batch
 }
 
+// respectMinBatchSize reports whether nextBatch should hold items back until
+// MinBatchSize of them are pending. b.mu must be held.
+func (b *Batcher) respectMinBatchSize() bool {
+	if b.shutdown {
+		// If we're shutting down, do not respect the minimum; flush whatever
+		// is pending. This takes priority.
+		return false
+	}
+	if b.opts.BatchTimeout > 0 {
+		// If we've waited longer than the maximum wait for sub-minimum batches,
+		// stop respecting the minimum and send what we have.
+		if !b.batchSizeTimeout.IsZero() && time.Since(b.batchSizeTimeout) >= b.opts.BatchTimeout {
+			return false
+		}
+	}
+	// We're not shutting down and no timeout has expired, so respect the
+	// minimum batch size.
+	return true
+}
+
 func (b *Batcher) callHandler(batch []waiter) {
 	for batch != nil {
@@ -283,5 +349,13 @@ func (b *Batcher) Shutdown() {
 	b.mu.Lock()
 	b.shutdown = true
+
+	// On shutdown, flush any pending items even if fewer than MinBatchSize
+	// are pending; respectMinBatchSize returns false once shutdown is set.
+	if b.nHandlers < b.opts.MaxHandlers {
+		batch := b.nextBatch()
+		b.handleBatch(batch)
+	}
 	b.mu.Unlock()
+
 	b.wg.Wait()
 }
diff --git a/pubsub/batcher/batcher_test.go b/pubsub/batcher/batcher_test.go
index e7c5dd96c1..3fe3fff1ed 100644
--- a/pubsub/batcher/batcher_test.go
+++ b/pubsub/batcher/batcher_test.go
@@ -171,6 +171,63 @@ func TestMinBatchSize(t *testing.T) {
 	}
 }
 
+// TestMinBatchSizeFlushesAfterTimeout ensures that the BatchTimeout option flushes
+// batches, even if the pending count is less than the minimum batch size.
+func TestMinBatchSizeFlushesAfterTimeout(t *testing.T) {
+	var got [][]int
+
+	batchSize := 3
+	opts := &batcher.Options{MinBatchSize: batchSize, BatchTimeout: 10 * time.Millisecond}
+
+	b := batcher.New(reflect.TypeOf(int(0)), opts, func(items interface{}) error {
+		got = append(got, items.([]int))
+		return nil
+	})
+	for i := 0; i < (batchSize - 1); i++ {
+		b.AddNoWait(i)
+	}
+
+	// Ensure that we've received nothing yet.
+	if len(got) > 0 {
+		t.Errorf("got batch unexpectedly: %+v", got)
+	}
+
+	<-time.After(opts.BatchTimeout + 5*time.Millisecond)
+
+	want := [][]int{{0, 1}}
+	if !cmp.Equal(got, want) {
+		t.Errorf("got %+v, want %+v after timeout", got, want)
+	}
+}
+
+// TestMinBatchSizeFlushesOnShutdown ensures that Shutdown() flushes batches, even if
+// the pending count is less than the minimum batch size.
+func TestMinBatchSizeFlushesOnShutdown(t *testing.T) {
+	var got [][]int
+
+	batchSize := 3
+
+	b := batcher.New(reflect.TypeOf(int(0)), &batcher.Options{MinBatchSize: batchSize}, func(items interface{}) error {
+		got = append(got, items.([]int))
+		return nil
+	})
+	for i := 0; i < (batchSize - 1); i++ {
+		b.AddNoWait(i)
+	}
+
+	// Ensure that we've received nothing yet.
+	if len(got) > 0 {
+		t.Errorf("got batch unexpectedly: %+v", got)
+	}
+
+	b.Shutdown()
+
+	want := [][]int{{0, 1}}
+	if !cmp.Equal(got, want) {
+		t.Errorf("got %+v, want %+v on shutdown", got, want)
+	}
+}
+
 func TestSaturation(t *testing.T) {
 	// Verify that under high load the maximum number of handlers are running.
 	ctx := context.Background()
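
For context, here is a minimal usage sketch (not part of the diff) showing how the new BatchTimeout option interacts with MinBatchSize. It assumes the gocloud.dev/pubsub/batcher import path; the New, AddNoWait, and Shutdown calls mirror the tests above.

```go
package main

import (
	"fmt"
	"reflect"
	"time"

	"gocloud.dev/pubsub/batcher"
)

func main() {
	// Hold items until 10 are pending, but flush a partial batch once it
	// has waited 50ms (the new BatchTimeout option).
	opts := &batcher.Options{MinBatchSize: 10, BatchTimeout: 50 * time.Millisecond}

	b := batcher.New(reflect.TypeOf(int(0)), opts, func(items interface{}) error {
		fmt.Printf("flushed: %v\n", items.([]int))
		return nil
	})

	// Three items is below MinBatchSize, so nothing is flushed immediately.
	for i := 0; i < 3; i++ {
		b.AddNoWait(i)
	}

	// After BatchTimeout elapses, the timeout goroutine flushes {0, 1, 2}.
	time.Sleep(100 * time.Millisecond)
	b.Shutdown()
}
```

Without BatchTimeout set, the same three items would sit in pending until either more items arrive to reach MinBatchSize or Shutdown flushes them.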