client: genericize the batch controller #8793

Merged · 4 commits · Nov 11, 2024
236 changes: 236 additions & 0 deletions client/batch_controller.go
@@ -0,0 +1,236 @@
// Copyright 2024 TiKV Project Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pd

import (
"context"
"time"

"github.com/prometheus/client_golang/prometheus"
)

// Starting from a low value is necessary because we need to make sure it converges to (current_batch_size - 4).
const defaultBestBatchSize = 8

// finisherFunc is used to finish a request. It accepts the index of the request in the batch, the request itself, and an error.
type finisherFunc[T any] func(int, T, error)

type batchController[T any] struct {
maxBatchSize int
// bestBatchSize is a dynamic size that changes based on the effect of the current batch.
bestBatchSize int

collectedRequests []T
collectedRequestCount int

// The finisher function to cancel collected requests when an internal error occurs.
finisher finisherFunc[T]
// The observer to record the best batch size.
bestBatchObserver prometheus.Histogram
// The time right after the first request and the token are obtained, and before extra batching is performed.
extraBatchingStartTime time.Time
}

func newBatchController[T any](maxBatchSize int, finisher finisherFunc[T], bestBatchObserver prometheus.Histogram) *batchController[T] {
return &batchController[T]{
maxBatchSize: maxBatchSize,
bestBatchSize: defaultBestBatchSize,
collectedRequests: make([]T, maxBatchSize+1),
collectedRequestCount: 0,
finisher: finisher,
bestBatchObserver: bestBatchObserver,
}
}

// fetchPendingRequests starts a new round of collecting a batch of requests from the channel.
// It returns a nil error if everything goes well, otherwise a non-nil error which means the service should be stopped.
// It's guaranteed that if this function fails after collecting some requests, those requests will be cancelled
// when the function returns, so the caller doesn't need to clear them manually.
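// On a nil return, a token has been taken from `tokenCh` and is still held by the caller;
// on a non-nil return, any acquired token has already been pushed back into `tokenCh`.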
func (bc *batchController[T]) fetchPendingRequests(ctx context.Context, requestCh <-chan T, tokenCh chan struct{}, maxBatchWaitInterval time.Duration) (errRet error) {
var tokenAcquired bool
defer func() {
if errRet != nil {
// Something went wrong when collecting a batch of requests. Release the token and cancel collected requests
// if any.
if tokenAcquired {
tokenCh <- struct{}{}
}
bc.finishCollectedRequests(bc.finisher, errRet)
}
}()

// Wait until BOTH the first request and the token have arrived.
// TODO: `bc.collectedRequestCount` should always be zero here. Consider adding an assertion.
bc.collectedRequestCount = 0
for {
// If the batch size reaches the maxBatchSize limit but the token hasn't arrived yet, don't receive more
// requests, and return once the token is ready.
if bc.collectedRequestCount >= bc.maxBatchSize && !tokenAcquired {
select {
case <-ctx.Done():
return ctx.Err()
case <-tokenCh:
return nil
}
}

select {
case <-ctx.Done():
return ctx.Err()
case req := <-requestCh:
// Start to batch when the first request arrives.
bc.pushRequest(req)
// A request has arrived but the token is not ready yet. Continue waiting, while also allowing the next
// request to be collected if it arrives.
continue
case <-tokenCh:
tokenAcquired = true
}

// The token is ready. If the first request hasn't arrived yet, wait for it.
if bc.collectedRequestCount == 0 {
select {
case <-ctx.Done():
return ctx.Err()
case firstRequest := <-requestCh:
bc.pushRequest(firstRequest)
}
}

// Both token and the first request have arrived.
break
}

bc.extraBatchingStartTime = time.Now()

// This loop tries its best to collect more requests, so we use `bc.maxBatchSize` here.
fetchPendingRequestsLoop:
for bc.collectedRequestCount < bc.maxBatchSize {
select {
case req := <-requestCh:
bc.pushRequest(req)
case <-ctx.Done():
return ctx.Err()
default:
break fetchPendingRequestsLoop
}
}

// Check whether we should fetch more pending requests from the channel.
if bc.collectedRequestCount >= bc.maxBatchSize || maxBatchWaitInterval <= 0 {
return nil
}

// Fetch more pending requests from the channel.
// Try to collect up to `bc.bestBatchSize` requests, waiting at most `maxBatchWaitInterval`
// while `bc.collectedRequestCount` is still less than `bc.bestBatchSize`.
if bc.collectedRequestCount < bc.bestBatchSize {
after := time.NewTimer(maxBatchWaitInterval)
defer after.Stop()
for bc.collectedRequestCount < bc.bestBatchSize {
select {
case req := <-requestCh:
bc.pushRequest(req)
case <-ctx.Done():
return ctx.Err()
case <-after.C:
return nil
}
}
}

// Do an additional non-blocking try. Here we test the length against `bc.maxBatchSize` instead
// of `bc.bestBatchSize` because fetching as many requests as possible is necessary so that
// `bc.bestBatchSize` can be adjusted dynamically later.
for bc.collectedRequestCount < bc.maxBatchSize {
select {
case req := <-requestCh:
bc.pushRequest(req)
case <-ctx.Done():
return ctx.Err()
default:
return nil
}
}
return nil
}

// fetchRequestsWithTimer tries to fetch requests until the given timer ticks. The caller must set the timer properly
// before calling this function.
func (bc *batchController[T]) fetchRequestsWithTimer(ctx context.Context, requestCh <-chan T, timer *time.Timer) error {
batchingLoop:
for bc.collectedRequestCount < bc.maxBatchSize {
select {
case <-ctx.Done():
return ctx.Err()
case req := <-requestCh:
bc.pushRequest(req)
case <-timer.C:
break batchingLoop
}
}

// Try to collect more requests in a non-blocking way.
nonWaitingBatchLoop:
for bc.collectedRequestCount < bc.maxBatchSize {
select {
case <-ctx.Done():
return ctx.Err()
case req := <-requestCh:
bc.pushRequest(req)
default:
break nonWaitingBatchLoop
}
}

return nil
}

func (bc *batchController[T]) pushRequest(req T) {
bc.collectedRequests[bc.collectedRequestCount] = req
bc.collectedRequestCount++
}

func (bc *batchController[T]) getCollectedRequests() []T {
return bc.collectedRequests[:bc.collectedRequestCount]
}

// adjustBestBatchSize stabilizes the latency with the AIAD algorithm.
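// Each adjustment moves `bc.bestBatchSize` by at most one: it is decremented when the collected batch
// was smaller than the current target (we waited too long for too few requests), and incremented only
// when the batch exceeded the target by more than 4, which keeps the target from oscillating.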
func (bc *batchController[T]) adjustBestBatchSize() {
if bc.bestBatchObserver != nil {
bc.bestBatchObserver.Observe(float64(bc.bestBatchSize))
}
length := bc.collectedRequestCount
if length < bc.bestBatchSize && bc.bestBatchSize > 1 {
// Waited too long to collect enough requests; reduce the target batch size.
bc.bestBatchSize--
} else if length > bc.bestBatchSize+4 /* Hard-coded number, in order to make `bc.bestBatchSize` stable */ &&
bc.bestBatchSize < bc.maxBatchSize {
bc.bestBatchSize++
}
}

func (bc *batchController[T]) finishCollectedRequests(finisher finisherFunc[T], err error) {
if finisher == nil {
finisher = bc.finisher
}
if finisher != nil {
for i := range bc.collectedRequestCount {
finisher(i, bc.collectedRequests[i], err)
}
}
// Prevent the finished requests from being processed again.
bc.collectedRequestCount = 0
}
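
Not part of this PR: a minimal in-package usage sketch showing the collect/process/adjust/finish cycle and the token handshake that `fetchPendingRequests` expects. The `demoRequest` type, `runDemoDispatcher` function, and the maxBatchSize/maxBatchWaitInterval values are hypothetical, chosen only for illustration.

package pd

import (
	"context"
	"time"
)

// demoRequest is a hypothetical request type used only for illustration.
type demoRequest struct {
	done chan error
}

// runDemoDispatcher drives the controller through one possible lifecycle:
// collect a batch, process it, adjust the target batch size, finish the
// requests, and return the token for the next round.
func runDemoDispatcher(ctx context.Context, requestCh chan demoRequest) {
	// The finisher delivers the terminal error (or nil) back to each request.
	finisher := func(_ int, req demoRequest, err error) {
		req.done <- err
	}
	bc := newBatchController[demoRequest](100 /* maxBatchSize */, finisher, nil /* no histogram */)

	// A single token makes batch collection and downstream processing mutually exclusive.
	tokenCh := make(chan struct{}, 1)
	tokenCh <- struct{}{}

	for {
		// On error the controller has already cancelled the collected requests
		// and pushed the token back, so the loop can simply stop.
		if err := bc.fetchPendingRequests(ctx, requestCh, tokenCh, 10*time.Millisecond); err != nil {
			return
		}
		_ = bc.getCollectedRequests() // process the batch here (omitted)
		// Tune the target batch size based on what was just collected, then
		// finish the requests and hand the token back.
		bc.adjustBestBatchSize()
		bc.finishCollectedRequests(nil, nil)
		tokenCh <- struct{}{}
	}
}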
83 changes: 83 additions & 0 deletions client/batch_controller_test.go
@@ -0,0 +1,83 @@
// Copyright 2024 TiKV Project Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pd

import (
"context"
"testing"

"github.com/stretchr/testify/require"
)

func TestAdjustBestBatchSize(t *testing.T) {
re := require.New(t)
bc := newBatchController[int](20, nil, nil)
re.Equal(defaultBestBatchSize, bc.bestBatchSize)
bc.adjustBestBatchSize()
re.Equal(defaultBestBatchSize-1, bc.bestBatchSize)
// Clear the collected requests.
bc.finishCollectedRequests(nil, nil)
// Push 10 requests; this should not increase the best batch size.
for i := range 10 {
bc.pushRequest(i)
}
bc.adjustBestBatchSize()
re.Equal(defaultBestBatchSize-1, bc.bestBatchSize)
bc.finishCollectedRequests(nil, nil)
// Push 15 requests; this should increase the best batch size.
for i := range 15 {
bc.pushRequest(i)
}
bc.adjustBestBatchSize()
re.Equal(defaultBestBatchSize, bc.bestBatchSize)
bc.finishCollectedRequests(nil, nil)
}

type testRequest struct {
idx int
err error
}

func TestFinishCollectedRequests(t *testing.T) {
re := require.New(t)
bc := newBatchController[*testRequest](20, nil, nil)
// Finish with zero request count.
re.Zero(bc.collectedRequestCount)
bc.finishCollectedRequests(nil, nil)
re.Zero(bc.collectedRequestCount)
// Finish with non-zero request count.
requests := make([]*testRequest, 10)
for i := range 10 {
requests[i] = &testRequest{}
bc.pushRequest(requests[i])
}
re.Equal(10, bc.collectedRequestCount)
bc.finishCollectedRequests(nil, nil)
re.Zero(bc.collectedRequestCount)
// Finish with custom finisher.
for i := range 10 {
requests[i] = &testRequest{}
bc.pushRequest(requests[i])
}
bc.finishCollectedRequests(func(idx int, tr *testRequest, err error) {
tr.idx = idx
tr.err = err
}, context.Canceled)
re.Zero(bc.collectedRequestCount)
for i := range 10 {
re.Equal(i, requests[i].idx)
re.Equal(context.Canceled, requests[i].err)
}
}
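
Assuming the module layout implied by the paths above, these tests can be run from inside the client directory with the standard Go toolchain (the `for i := range 10` form requires Go 1.22 or newer):

go test -run 'TestAdjustBestBatchSize|TestFinishCollectedRequests' -v .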