From f2401877e2bd27f8885841c16ab8b18a949c0ca8 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Fri, 8 Nov 2024 17:25:30 +0800 Subject: [PATCH] Add the batch controller test cases Signed-off-by: JmPotato --- client/batch_controller.go | 26 +++++++---- client/batch_controller_test.go | 83 +++++++++++++++++++++++++++++++++ client/tso_dispatcher.go | 10 ++-- 3 files changed, 106 insertions(+), 13 deletions(-) create mode 100644 client/batch_controller_test.go diff --git a/client/batch_controller.go b/client/batch_controller.go index 04de66a30948..f48b1d4d9385 100644 --- a/client/batch_controller.go +++ b/client/batch_controller.go @@ -21,6 +21,12 @@ import ( "github.com/prometheus/client_golang/prometheus" ) +// Starting from a low value is necessary because we need to make sure it will be converged to (current_batch_size - 4). +const defaultBestBatchSize = 8 + +// finisherFunc is used to finish a request, it accepts the index of the request in the batch, the request itself and an error. +type finisherFunc[T any] func(int, T, error) + type batchController[T any] struct { maxBatchSize int // bestBatchSize is a dynamic size that changed based on the current batch effect. @@ -30,17 +36,17 @@ type batchController[T any] struct { collectedRequestCount int // The finisher function to cancel collected requests when an internal error occurs. - cancelFinisher func(int, T) + cancelFinisher finisherFunc[T] // The observer to record the best batch size. bestBatchObserver prometheus.Histogram // The time after getting the first request and the token, and before performing extra batching. extraBatchingStartTime time.Time } -func newBatchController[T any](maxBatchSize int, cancelFinisher func(int, T), bestBatchObserver prometheus.Histogram) *batchController[T] { +func newBatchController[T any](maxBatchSize int, cancelFinisher finisherFunc[T], bestBatchObserver prometheus.Histogram) *batchController[T] { return &batchController[T]{ maxBatchSize: maxBatchSize, - bestBatchSize: 8, /* Starting from a low value is necessary because we need to make sure it will be converged to (current_batch_size - 4) */ + bestBatchSize: defaultBestBatchSize, collectedRequests: make([]T, maxBatchSize+1), collectedRequestCount: 0, cancelFinisher: cancelFinisher, @@ -61,7 +67,7 @@ func (bc *batchController[T]) fetchPendingRequests(ctx context.Context, requestC if tokenAcquired { tokenCh <- struct{}{} } - bc.finishCollectedRequests(bc.cancelFinisher) + bc.finishCollectedRequests(bc.cancelFinisher, errRet) } }() @@ -203,7 +209,9 @@ func (bc *batchController[T]) getCollectedRequests() []T { // adjustBestBatchSize stabilizes the latency with the AIAD algorithm. func (bc *batchController[T]) adjustBestBatchSize() { - bc.bestBatchObserver.Observe(float64(bc.bestBatchSize)) + if bc.bestBatchObserver != nil { + bc.bestBatchObserver.Observe(float64(bc.bestBatchSize)) + } length := bc.collectedRequestCount if length < bc.bestBatchSize && bc.bestBatchSize > 1 { // Waits too long to collect requests, reduce the target batch size. @@ -214,12 +222,14 @@ func (bc *batchController[T]) adjustBestBatchSize() { } } -func (bc *batchController[T]) finishCollectedRequests(finisher func(int, T)) { +func (bc *batchController[T]) finishCollectedRequests(finisher finisherFunc[T], err error) { if finisher == nil { finisher = bc.cancelFinisher } - for i := range bc.collectedRequestCount { - finisher(i, bc.collectedRequests[i]) + if finisher != nil { + for i := range bc.collectedRequestCount { + finisher(i, bc.collectedRequests[i], err) + } } // Prevent the finished requests from being processed again. bc.collectedRequestCount = 0 diff --git a/client/batch_controller_test.go b/client/batch_controller_test.go new file mode 100644 index 000000000000..b4a8a04dc880 --- /dev/null +++ b/client/batch_controller_test.go @@ -0,0 +1,83 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pd + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAdjustBestBatchSize(t *testing.T) { + re := require.New(t) + bc := newBatchController[int](20, nil, nil) + re.Equal(defaultBestBatchSize, bc.bestBatchSize) + bc.adjustBestBatchSize() + re.Equal(defaultBestBatchSize-1, bc.bestBatchSize) + // Clear the collected requests. + bc.finishCollectedRequests(nil, nil) + // Push 10 requests - do not increase the best batch size. + for i := range 10 { + bc.pushRequest(i) + } + bc.adjustBestBatchSize() + re.Equal(defaultBestBatchSize-1, bc.bestBatchSize) + bc.finishCollectedRequests(nil, nil) + // Push 15 requests, increase the best batch size. + for i := range 15 { + bc.pushRequest(i) + } + bc.adjustBestBatchSize() + re.Equal(defaultBestBatchSize, bc.bestBatchSize) + bc.finishCollectedRequests(nil, nil) +} + +type testRequest struct { + idx int + err error +} + +func TestFinishCollectedRequests(t *testing.T) { + re := require.New(t) + bc := newBatchController[*testRequest](20, nil, nil) + // Finish with zero request count. + re.Zero(bc.collectedRequestCount) + bc.finishCollectedRequests(nil, nil) + re.Zero(bc.collectedRequestCount) + // Finish with non-zero request count. + requests := make([]*testRequest, 10) + for i := range 10 { + requests[i] = &testRequest{} + bc.pushRequest(requests[i]) + } + re.Equal(10, bc.collectedRequestCount) + bc.finishCollectedRequests(nil, nil) + re.Zero(bc.collectedRequestCount) + // Finish with custom finisher. + for i := range 10 { + requests[i] = &testRequest{} + bc.pushRequest(requests[i]) + } + bc.finishCollectedRequests(func(idx int, tr *testRequest, err error) { + tr.idx = idx + tr.err = err + }, context.Canceled) + re.Zero(bc.collectedRequestCount) + for i := range 10 { + re.Equal(i, requests[i].idx) + re.Equal(context.Canceled, requests[i].err) + } +} diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index 00bc9dff7791..3b9e7b070ce1 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -128,7 +128,7 @@ func newTSODispatcher( New: func() any { return newBatchController[*tsoRequest]( maxBatchSize*2, - tsoRequestFinisherFactory(0, 0, 0, invalidStreamID, nil), + tsoRequestFinisherFactory(0, 0, 0, invalidStreamID), tsoBestBatchSize, ) }, @@ -614,8 +614,8 @@ func (td *tsoDispatcher) processRequests( return nil } -func tsoRequestFinisherFactory(physical, firstLogical int64, suffixBits uint32, streamID string, err error) func(int, *tsoRequest) { - return func(idx int, tsoReq *tsoRequest) { +func tsoRequestFinisherFactory(physical, firstLogical int64, suffixBits uint32, streamID string) finisherFunc[*tsoRequest] { + return func(idx int, tsoReq *tsoRequest, err error) { // Retrieve the request context before the request is done to trace without race. requestCtx := tsoReq.requestCtx tsoReq.physical, tsoReq.logical = physical, tsoutil.AddLogical(firstLogical, int64(idx), suffixBits) @@ -627,12 +627,12 @@ func tsoRequestFinisherFactory(physical, firstLogical int64, suffixBits uint32, func (td *tsoDispatcher) cancelCollectedRequests(tbc *batchController[*tsoRequest], streamID string, err error) { td.tokenCh <- struct{}{} - tbc.finishCollectedRequests(tsoRequestFinisherFactory(0, 0, 0, streamID, err)) + tbc.finishCollectedRequests(tsoRequestFinisherFactory(0, 0, 0, streamID), err) } func (td *tsoDispatcher) doneCollectedRequests(tbc *batchController[*tsoRequest], physical, firstLogical int64, suffixBits uint32, streamID string) { td.tokenCh <- struct{}{} - tbc.finishCollectedRequests(tsoRequestFinisherFactory(physical, firstLogical, suffixBits, streamID, nil)) + tbc.finishCollectedRequests(tsoRequestFinisherFactory(physical, firstLogical, suffixBits, streamID), nil) } // checkMonotonicity checks whether the monotonicity of the TSO allocation is violated.