Add the batch controller test cases
Signed-off-by: JmPotato <[email protected]>
JmPotato committed Nov 8, 2024
1 parent b4fcc67 commit f240187
Showing 3 changed files with 106 additions and 13 deletions.
26 changes: 18 additions & 8 deletions client/batch_controller.go
@@ -21,6 +21,12 @@ import (
"github.com/prometheus/client_golang/prometheus"
)

// Starting from a low value is necessary because we need to make sure it converges to (current_batch_size - 4).
const defaultBestBatchSize = 8

// finisherFunc is used to finish a request. It accepts the index of the request in the batch, the request itself, and an error.
type finisherFunc[T any] func(int, T, error)

type batchController[T any] struct {
maxBatchSize int
// bestBatchSize is a dynamic size that changes based on the current batch effect.
@@ -30,17 +36,17 @@ type batchController[T any] struct {
collectedRequestCount int

// The finisher function to cancel collected requests when an internal error occurs.
cancelFinisher func(int, T)
cancelFinisher finisherFunc[T]
// The observer to record the best batch size.
bestBatchObserver prometheus.Histogram
// The time after getting the first request and the token, and before performing extra batching.
extraBatchingStartTime time.Time
}

func newBatchController[T any](maxBatchSize int, cancelFinisher func(int, T), bestBatchObserver prometheus.Histogram) *batchController[T] {
func newBatchController[T any](maxBatchSize int, cancelFinisher finisherFunc[T], bestBatchObserver prometheus.Histogram) *batchController[T] {
return &batchController[T]{
maxBatchSize: maxBatchSize,
bestBatchSize: 8, /* Starting from a low value is necessary because we need to make sure it will be converged to (current_batch_size - 4) */
bestBatchSize: defaultBestBatchSize,
collectedRequests: make([]T, maxBatchSize+1),
collectedRequestCount: 0,
cancelFinisher: cancelFinisher,
@@ -61,7 +67,7 @@ func (bc *batchController[T]) fetchPendingRequests(ctx context.Context, requestC
if tokenAcquired {
tokenCh <- struct{}{}
}
bc.finishCollectedRequests(bc.cancelFinisher)
bc.finishCollectedRequests(bc.cancelFinisher, errRet)
}
}()

@@ -203,7 +209,9 @@ func (bc *batchController[T]) getCollectedRequests() []T {

// adjustBestBatchSize stabilizes the latency with the AIAD algorithm.
func (bc *batchController[T]) adjustBestBatchSize() {
bc.bestBatchObserver.Observe(float64(bc.bestBatchSize))
if bc.bestBatchObserver != nil {
bc.bestBatchObserver.Observe(float64(bc.bestBatchSize))
}
length := bc.collectedRequestCount
if length < bc.bestBatchSize && bc.bestBatchSize > 1 {
// Waited too long to collect requests; reduce the target batch size.
@@ -214,12 +222,14 @@
}
}

func (bc *batchController[T]) finishCollectedRequests(finisher func(int, T)) {
func (bc *batchController[T]) finishCollectedRequests(finisher finisherFunc[T], err error) {
if finisher == nil {
finisher = bc.cancelFinisher
}
for i := range bc.collectedRequestCount {
finisher(i, bc.collectedRequests[i])
if finisher != nil {
for i := range bc.collectedRequestCount {
finisher(i, bc.collectedRequests[i], err)
}
}
// Prevent the finished requests from being processed again.
bc.collectedRequestCount = 0
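For context, the decrease branch visible in the hunk above is half of an additive-increase/additive-decrease (AIAD) loop. Below is a minimal, standalone Go sketch of that idea; it is not the PD client code, and the +4 slack on the increase side is an assumption inferred from the comment about converging to (current_batch_size - 4):

```go
package main

import "fmt"

// aiad is a toy target-batch-size tuner illustrating the AIAD idea.
type aiad struct {
	maxBatchSize  int
	bestBatchSize int
}

func (a *aiad) adjust(collected int) {
	switch {
	case collected < a.bestBatchSize && a.bestBatchSize > 1:
		// The batch closed before reaching the target: lower the target by one.
		a.bestBatchSize--
	case collected > a.bestBatchSize+4 && a.bestBatchSize < a.maxBatchSize:
		// Requests keep overshooting the target: raise it by one.
		a.bestBatchSize++
	}
}

func main() {
	a := &aiad{maxBatchSize: 20, bestBatchSize: 8}
	for _, collected := range []int{3, 3, 3, 15, 15, 15, 15} {
		a.adjust(collected)
		fmt.Printf("collected=%d -> bestBatchSize=%d\n", collected, a.bestBatchSize)
	}
}
```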
83 changes: 83 additions & 0 deletions client/batch_controller_test.go
@@ -0,0 +1,83 @@
// Copyright 2024 TiKV Project Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pd

import (
"context"
"testing"

"github.com/stretchr/testify/require"
)

func TestAdjustBestBatchSize(t *testing.T) {
re := require.New(t)
bc := newBatchController[int](20, nil, nil)
re.Equal(defaultBestBatchSize, bc.bestBatchSize)
bc.adjustBestBatchSize()
re.Equal(defaultBestBatchSize-1, bc.bestBatchSize)
// Clear the collected requests.
bc.finishCollectedRequests(nil, nil)
// Push 10 requests; this should not increase the best batch size.
for i := range 10 {
bc.pushRequest(i)
}
bc.adjustBestBatchSize()
re.Equal(defaultBestBatchSize-1, bc.bestBatchSize)
bc.finishCollectedRequests(nil, nil)
// Push 15 requests; this should increase the best batch size.
for i := range 15 {
bc.pushRequest(i)
}
bc.adjustBestBatchSize()
re.Equal(defaultBestBatchSize, bc.bestBatchSize)
bc.finishCollectedRequests(nil, nil)
}

type testRequest struct {
idx int
err error
}

func TestFinishCollectedRequests(t *testing.T) {
re := require.New(t)
bc := newBatchController[*testRequest](20, nil, nil)
// Finish with zero request count.
re.Zero(bc.collectedRequestCount)
bc.finishCollectedRequests(nil, nil)
re.Zero(bc.collectedRequestCount)
// Finish with non-zero request count.
requests := make([]*testRequest, 10)
for i := range 10 {
requests[i] = &testRequest{}
bc.pushRequest(requests[i])
}
re.Equal(10, bc.collectedRequestCount)
bc.finishCollectedRequests(nil, nil)
re.Zero(bc.collectedRequestCount)
// Finish with custom finisher.
for i := range 10 {
requests[i] = &testRequest{}
bc.pushRequest(requests[i])
}
bc.finishCollectedRequests(func(idx int, tr *testRequest, err error) {
tr.idx = idx
tr.err = err
}, context.Canceled)
re.Zero(bc.collectedRequestCount)
for i := range 10 {
re.Equal(i, requests[i].idx)
re.Equal(context.Canceled, requests[i].err)
}
}
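Both tests drive the batch controller directly with a nil observer and a nil cancel finisher, so no dispatcher goroutine is needed; note that the `for i := range 10` loops range over an integer, which compiles only with Go 1.22 or newer.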
10 changes: 5 additions & 5 deletions client/tso_dispatcher.go
@@ -128,7 +128,7 @@ func newTSODispatcher(
New: func() any {
return newBatchController[*tsoRequest](
maxBatchSize*2,
tsoRequestFinisherFactory(0, 0, 0, invalidStreamID, nil),
tsoRequestFinisherFactory(0, 0, 0, invalidStreamID),
tsoBestBatchSize,
)
},
@@ -614,8 +614,8 @@ func (td *tsoDispatcher) processRequests(
return nil
}

func tsoRequestFinisherFactory(physical, firstLogical int64, suffixBits uint32, streamID string, err error) func(int, *tsoRequest) {
return func(idx int, tsoReq *tsoRequest) {
func tsoRequestFinisherFactory(physical, firstLogical int64, suffixBits uint32, streamID string) finisherFunc[*tsoRequest] {
return func(idx int, tsoReq *tsoRequest, err error) {
// Retrieve the request context before the request is done to trace without race.
requestCtx := tsoReq.requestCtx
tsoReq.physical, tsoReq.logical = physical, tsoutil.AddLogical(firstLogical, int64(idx), suffixBits)
@@ -627,12 +627,12 @@ func tsoRequestFinisherFactory(physical, firstLogical int64, suffixBits uint32,

func (td *tsoDispatcher) cancelCollectedRequests(tbc *batchController[*tsoRequest], streamID string, err error) {
td.tokenCh <- struct{}{}
tbc.finishCollectedRequests(tsoRequestFinisherFactory(0, 0, 0, streamID, err))
tbc.finishCollectedRequests(tsoRequestFinisherFactory(0, 0, 0, streamID), err)
}

func (td *tsoDispatcher) doneCollectedRequests(tbc *batchController[*tsoRequest], physical, firstLogical int64, suffixBits uint32, streamID string) {
td.tokenCh <- struct{}{}
tbc.finishCollectedRequests(tsoRequestFinisherFactory(physical, firstLogical, suffixBits, streamID, nil))
tbc.finishCollectedRequests(tsoRequestFinisherFactory(physical, firstLogical, suffixBits, streamID), nil)
}

// checkMonotonicity checks whether the monotonicity of the TSO allocation is violated.
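The tso_dispatcher change above stops baking the error into tsoRequestFinisherFactory and instead passes it through each finisher call, so the same factory can serve both the done and the cancel paths. A stripped-down sketch of that closure pattern, using hypothetical names (fakeReq, resultFinisherFactory) rather than the real PD client types:

```go
package main

import (
	"errors"
	"fmt"
)

// finisher mirrors the finisherFunc shape from the diff: batch index, the request, and an error.
type finisher[T any] func(idx int, req T, err error)

// fakeReq is a stand-in request type for this sketch only.
type fakeReq struct {
	ts  int64
	err error
}

// resultFinisherFactory captures only the result fields; the error is supplied
// per call instead of being captured by the closure.
func resultFinisherFactory(base int64) finisher[*fakeReq] {
	return func(idx int, req *fakeReq, err error) {
		req.ts = base + int64(idx)
		req.err = err
	}
}

func main() {
	reqs := []*fakeReq{{}, {}, {}}
	finish := resultFinisherFactory(100)
	for i, r := range reqs {
		finish(i, r, nil) // success path: assign results, no error
	}
	finish(0, reqs[0], errors.New("stream closed")) // cancel path reuses the same factory
	fmt.Println(reqs[0].ts, reqs[0].err, reqs[1].ts, reqs[2].ts)
}
```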
