From 176ab2364abff3a3729515be23becf0cb2798cf1 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 26 Nov 2024 15:40:11 +0800 Subject: [PATCH] client: separate the TSO client implementation (#8848) ref tikv/pd#8690 Separate the TSO client implementation. Signed-off-by: JmPotato Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- client/{ => batch}/batch_controller.go | 48 +++++++---- client/{ => batch}/batch_controller_test.go | 24 +++--- client/client.go | 12 +-- client/client_test.go | 28 ------- client/{ => clients/tso}/tso_client.go | 80 ++++++++++--------- client/{ => clients/tso}/tso_dispatcher.go | 50 ++++++------ .../{ => clients/tso}/tso_dispatcher_test.go | 46 +++++------ client/{ => clients/tso}/tso_request.go | 28 ++++--- client/clients/tso/tso_request_test.go | 50 ++++++++++++ client/{ => clients/tso}/tso_stream.go | 24 +++--- client/{ => clients/tso}/tso_stream_test.go | 2 +- client/constants/constants.go | 7 ++ client/inner_client.go | 33 ++++---- client/resource_manager_client.go | 5 +- tests/integrations/client/client_test.go | 16 ++-- tests/integrations/tso/client_test.go | 7 +- 16 files changed, 264 insertions(+), 196 deletions(-) rename client/{ => batch}/batch_controller.go (79%) rename client/{ => batch}/batch_controller_test.go (81%) rename client/{ => clients/tso}/tso_client.go (90%) rename client/{ => clients/tso}/tso_dispatcher.go (95%) rename client/{ => clients/tso}/tso_dispatcher_test.go (83%) rename client/{ => clients/tso}/tso_request.go (78%) create mode 100644 client/clients/tso/tso_request_test.go rename client/{ => clients/tso}/tso_stream.go (95%) rename client/{ => clients/tso}/tso_stream_test.go (99%) diff --git a/client/batch_controller.go b/client/batch/batch_controller.go similarity index 79% rename from client/batch_controller.go rename to client/batch/batch_controller.go index a19c3181ccd..32f0aaba1ae 100644 --- a/client/batch_controller.go +++ b/client/batch/batch_controller.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pd +package batch import ( "context" @@ -24,10 +24,11 @@ import ( // Starting from a low value is necessary because we need to make sure it will be converged to (current_batch_size - 4). const defaultBestBatchSize = 8 -// finisherFunc is used to finish a request, it accepts the index of the request in the batch, the request itself and an error. -type finisherFunc[T any] func(int, T, error) +// FinisherFunc is used to finish a request, it accepts the index of the request in the batch, the request itself and an error. +type FinisherFunc[T any] func(int, T, error) -type batchController[T any] struct { +// Controller is used to batch requests. +type Controller[T any] struct { maxBatchSize int // bestBatchSize is a dynamic size that changed based on the current batch effect. bestBatchSize int @@ -36,15 +37,16 @@ type batchController[T any] struct { collectedRequestCount int // The finisher function to cancel collected requests when an internal error occurs. - finisher finisherFunc[T] + finisher FinisherFunc[T] // The observer to record the best batch size. bestBatchObserver prometheus.Histogram // The time after getting the first request and the token, and before performing extra batching. extraBatchingStartTime time.Time } -func newBatchController[T any](maxBatchSize int, finisher finisherFunc[T], bestBatchObserver prometheus.Histogram) *batchController[T] { - return &batchController[T]{ +// NewController creates a new batch controller. +func NewController[T any](maxBatchSize int, finisher FinisherFunc[T], bestBatchObserver prometheus.Histogram) *Controller[T] { + return &Controller[T]{ maxBatchSize: maxBatchSize, bestBatchSize: defaultBestBatchSize, collectedRequests: make([]T, maxBatchSize+1), @@ -54,11 +56,11 @@ func newBatchController[T any](maxBatchSize int, finisher finisherFunc[T], bestB } } -// fetchPendingRequests will start a new round of the batch collecting from the channel. +// FetchPendingRequests will start a new round of the batch collecting from the channel. // It returns nil error if everything goes well, otherwise a non-nil error which means we should stop the service. // It's guaranteed that if this function failed after collecting some requests, then these requests will be cancelled // when the function returns, so the caller don't need to clear them manually. -func (bc *batchController[T]) fetchPendingRequests(ctx context.Context, requestCh <-chan T, tokenCh chan struct{}, maxBatchWaitInterval time.Duration) (errRet error) { +func (bc *Controller[T]) FetchPendingRequests(ctx context.Context, requestCh <-chan T, tokenCh chan struct{}, maxBatchWaitInterval time.Duration) (errRet error) { var tokenAcquired bool defer func() { if errRet != nil { @@ -67,7 +69,7 @@ func (bc *batchController[T]) fetchPendingRequests(ctx context.Context, requestC if tokenAcquired { tokenCh <- struct{}{} } - bc.finishCollectedRequests(bc.finisher, errRet) + bc.FinishCollectedRequests(bc.finisher, errRet) } }() @@ -167,9 +169,9 @@ fetchPendingRequestsLoop: return nil } -// fetchRequestsWithTimer tries to fetch requests until the given timer ticks. The caller must set the timer properly +// FetchRequestsWithTimer tries to fetch requests until the given timer ticks. The caller must set the timer properly // before calling this function. -func (bc *batchController[T]) fetchRequestsWithTimer(ctx context.Context, requestCh <-chan T, timer *time.Timer) error { +func (bc *Controller[T]) FetchRequestsWithTimer(ctx context.Context, requestCh <-chan T, timer *time.Timer) error { batchingLoop: for bc.collectedRequestCount < bc.maxBatchSize { select { @@ -198,17 +200,23 @@ nonWaitingBatchLoop: return nil } -func (bc *batchController[T]) pushRequest(req T) { +func (bc *Controller[T]) pushRequest(req T) { bc.collectedRequests[bc.collectedRequestCount] = req bc.collectedRequestCount++ } -func (bc *batchController[T]) getCollectedRequests() []T { +// GetCollectedRequests returns the collected requests. +func (bc *Controller[T]) GetCollectedRequests() []T { return bc.collectedRequests[:bc.collectedRequestCount] } -// adjustBestBatchSize stabilizes the latency with the AIAD algorithm. -func (bc *batchController[T]) adjustBestBatchSize() { +// GetCollectedRequestCount returns the number of collected requests. +func (bc *Controller[T]) GetCollectedRequestCount() int { + return bc.collectedRequestCount +} + +// AdjustBestBatchSize stabilizes the latency with the AIAD algorithm. +func (bc *Controller[T]) AdjustBestBatchSize() { if bc.bestBatchObserver != nil { bc.bestBatchObserver.Observe(float64(bc.bestBatchSize)) } @@ -222,7 +230,8 @@ func (bc *batchController[T]) adjustBestBatchSize() { } } -func (bc *batchController[T]) finishCollectedRequests(finisher finisherFunc[T], err error) { +// FinishCollectedRequests finishes the collected requests. +func (bc *Controller[T]) FinishCollectedRequests(finisher FinisherFunc[T], err error) { if finisher == nil { finisher = bc.finisher } @@ -234,3 +243,8 @@ func (bc *batchController[T]) finishCollectedRequests(finisher finisherFunc[T], // Prevent the finished requests from being processed again. bc.collectedRequestCount = 0 } + +// GetExtraBatchingStartTime returns the extra batching start time. +func (bc *Controller[T]) GetExtraBatchingStartTime() time.Time { + return bc.extraBatchingStartTime +} diff --git a/client/batch_controller_test.go b/client/batch/batch_controller_test.go similarity index 81% rename from client/batch_controller_test.go rename to client/batch/batch_controller_test.go index b4a8a04dc88..7c9ffa6944f 100644 --- a/client/batch_controller_test.go +++ b/client/batch/batch_controller_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pd +package batch import ( "context" @@ -23,26 +23,26 @@ import ( func TestAdjustBestBatchSize(t *testing.T) { re := require.New(t) - bc := newBatchController[int](20, nil, nil) + bc := NewController[int](20, nil, nil) re.Equal(defaultBestBatchSize, bc.bestBatchSize) - bc.adjustBestBatchSize() + bc.AdjustBestBatchSize() re.Equal(defaultBestBatchSize-1, bc.bestBatchSize) // Clear the collected requests. - bc.finishCollectedRequests(nil, nil) + bc.FinishCollectedRequests(nil, nil) // Push 10 requests - do not increase the best batch size. for i := range 10 { bc.pushRequest(i) } - bc.adjustBestBatchSize() + bc.AdjustBestBatchSize() re.Equal(defaultBestBatchSize-1, bc.bestBatchSize) - bc.finishCollectedRequests(nil, nil) + bc.FinishCollectedRequests(nil, nil) // Push 15 requests, increase the best batch size. for i := range 15 { bc.pushRequest(i) } - bc.adjustBestBatchSize() + bc.AdjustBestBatchSize() re.Equal(defaultBestBatchSize, bc.bestBatchSize) - bc.finishCollectedRequests(nil, nil) + bc.FinishCollectedRequests(nil, nil) } type testRequest struct { @@ -52,10 +52,10 @@ type testRequest struct { func TestFinishCollectedRequests(t *testing.T) { re := require.New(t) - bc := newBatchController[*testRequest](20, nil, nil) + bc := NewController[*testRequest](20, nil, nil) // Finish with zero request count. re.Zero(bc.collectedRequestCount) - bc.finishCollectedRequests(nil, nil) + bc.FinishCollectedRequests(nil, nil) re.Zero(bc.collectedRequestCount) // Finish with non-zero request count. requests := make([]*testRequest, 10) @@ -64,14 +64,14 @@ func TestFinishCollectedRequests(t *testing.T) { bc.pushRequest(requests[i]) } re.Equal(10, bc.collectedRequestCount) - bc.finishCollectedRequests(nil, nil) + bc.FinishCollectedRequests(nil, nil) re.Zero(bc.collectedRequestCount) // Finish with custom finisher. for i := range 10 { requests[i] = &testRequest{} bc.pushRequest(requests[i]) } - bc.finishCollectedRequests(func(idx int, tr *testRequest, err error) { + bc.FinishCollectedRequests(func(idx int, tr *testRequest, err error) { tr.idx = idx tr.err = err }, context.Canceled) diff --git a/client/client.go b/client/client.go index 6781182a44b..6fcb03576bb 100644 --- a/client/client.go +++ b/client/client.go @@ -33,6 +33,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/tikv/pd/client/caller" "github.com/tikv/pd/client/clients/metastorage" + "github.com/tikv/pd/client/clients/tso" "github.com/tikv/pd/client/constants" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" @@ -140,8 +141,7 @@ type RPCClient interface { // on your needs. WithCallerComponent(callerComponent caller.Component) RPCClient - // TSOClient is the TSO client. - TSOClient + tso.Client metastorage.Client // KeyspaceClient manages keyspace metadata. KeyspaceClient @@ -179,7 +179,7 @@ type serviceModeKeeper struct { // triggering service mode switching concurrently. sync.RWMutex serviceMode pdpb.ServiceMode - tsoClient *tsoClient + tsoClient *tso.Cli tsoSvcDiscovery sd.ServiceDiscovery } @@ -191,7 +191,7 @@ func (k *serviceModeKeeper) close() { k.tsoSvcDiscovery.Close() fallthrough case pdpb.ServiceMode_PD_SVC_MODE: - k.tsoClient.close() + k.tsoClient.Close() case pdpb.ServiceMode_UNKNOWN_SVC_MODE: } } @@ -557,7 +557,7 @@ func (c *client) getClientAndContext(ctx context.Context) (pdpb.PDClient, contex } // GetTSAsync implements the TSOClient interface. -func (c *client) GetTSAsync(ctx context.Context) TSFuture { +func (c *client) GetTSAsync(ctx context.Context) tso.TSFuture { defer trace.StartRegion(ctx, "pdclient.GetTSAsync").End() if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span = span.Tracer().StartSpan("pdclient.GetTSAsync", opentracing.ChildOf(span.Context())) @@ -570,7 +570,7 @@ func (c *client) GetTSAsync(ctx context.Context) TSFuture { // // Deprecated: Local TSO will be completely removed in the future. Currently, regardless of the // parameters passed in, this method will default to returning the global TSO. -func (c *client) GetLocalTSAsync(ctx context.Context, _ string) TSFuture { +func (c *client) GetLocalTSAsync(ctx context.Context, _ string) tso.TSFuture { return c.GetTSAsync(ctx) } diff --git a/client/client_test.go b/client/client_test.go index 8b4cc2242ca..36f6bb1b648 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -19,7 +19,6 @@ import ( "testing" "time" - "github.com/pingcap/errors" "github.com/stretchr/testify/require" "github.com/tikv/pd/client/caller" "github.com/tikv/pd/client/opt" @@ -62,30 +61,3 @@ func TestClientWithRetry(t *testing.T) { re.Error(err) re.Less(time.Since(start), time.Second*10) } - -func TestTsoRequestWait(t *testing.T) { - re := require.New(t) - ctx, cancel := context.WithCancel(context.Background()) - req := &tsoRequest{ - done: make(chan error, 1), - physical: 0, - logical: 0, - requestCtx: context.TODO(), - clientCtx: ctx, - } - cancel() - _, _, err := req.Wait() - re.ErrorIs(errors.Cause(err), context.Canceled) - - ctx, cancel = context.WithCancel(context.Background()) - req = &tsoRequest{ - done: make(chan error, 1), - physical: 0, - logical: 0, - requestCtx: ctx, - clientCtx: context.TODO(), - } - cancel() - _, _, err = req.Wait() - re.ErrorIs(errors.Cause(err), context.Canceled) -} diff --git a/client/tso_client.go b/client/clients/tso/tso_client.go similarity index 90% rename from client/tso_client.go rename to client/clients/tso/tso_client.go index 1d0a6385647..4343c36ca7f 100644 --- a/client/tso_client.go +++ b/client/clients/tso/tso_client.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pd +package tso import ( "context" @@ -25,6 +25,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/log" + "github.com/tikv/pd/client/constants" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" "github.com/tikv/pd/client/opt" @@ -41,13 +42,12 @@ import ( const ( // defaultMaxTSOBatchSize is the default max size of the TSO request batch. defaultMaxTSOBatchSize = 10000 - // retryInterval and maxRetryTimes are used to control the retry interval and max retry times. - retryInterval = 500 * time.Millisecond - maxRetryTimes = 6 + dispatchRetryDelay = 50 * time.Millisecond + dispatchRetryCount = 2 ) -// TSOClient is the client used to get timestamps. -type TSOClient interface { +// Client defines the interface of a TSO client. +type Client interface { // GetTS gets a timestamp from PD or TSO microservice. GetTS(ctx context.Context) (int64, int64, error) // GetTSAsync gets a timestamp from PD or TSO microservice, without block the caller. @@ -68,7 +68,8 @@ type TSOClient interface { GetLocalTSAsync(ctx context.Context, _ string) TSFuture } -type tsoClient struct { +// Cli is the implementation of the TSO client. +type Cli struct { ctx context.Context cancel context.CancelFunc wg sync.WaitGroup @@ -85,13 +86,13 @@ type tsoClient struct { dispatcher atomic.Pointer[tsoDispatcher] } -// newTSOClient returns a new TSO client. -func newTSOClient( +// NewClient returns a new TSO client. +func NewClient( ctx context.Context, option *opt.Option, svcDiscovery sd.ServiceDiscovery, factory tsoStreamBuilderFactory, -) *tsoClient { +) *Cli { ctx, cancel := context.WithCancel(ctx) - c := &tsoClient{ + c := &Cli{ ctx: ctx, cancel: cancel, option: option, @@ -99,7 +100,7 @@ func newTSOClient( tsoStreamBuilderFactory: factory, tsoReqPool: &sync.Pool{ New: func() any { - return &tsoRequest{ + return &Request{ done: make(chan error, 1), physical: 0, logical: 0, @@ -115,23 +116,29 @@ func newTSOClient( return c } -func (c *tsoClient) getOption() *opt.Option { return c.option } +func (c *Cli) getOption() *opt.Option { return c.option } -func (c *tsoClient) getServiceDiscovery() sd.ServiceDiscovery { return c.svcDiscovery } +func (c *Cli) getServiceDiscovery() sd.ServiceDiscovery { return c.svcDiscovery } -func (c *tsoClient) getDispatcher() *tsoDispatcher { +func (c *Cli) getDispatcher() *tsoDispatcher { return c.dispatcher.Load() } -func (c *tsoClient) setup() { +// GetRequestPool gets the request pool of the TSO client. +func (c *Cli) GetRequestPool() *sync.Pool { + return c.tsoReqPool +} + +// Setup initializes the TSO client. +func (c *Cli) Setup() { if err := c.svcDiscovery.CheckMemberChanged(); err != nil { log.Warn("[tso] failed to check member changed", errs.ZapError(err)) } c.tryCreateTSODispatcher() } -// close closes the TSO client -func (c *tsoClient) close() { +// Close closes the TSO client +func (c *Cli) Close() { if c == nil { return } @@ -146,12 +153,13 @@ func (c *tsoClient) close() { } // scheduleUpdateTSOConnectionCtxs update the TSO connection contexts. -func (c *tsoClient) scheduleUpdateTSOConnectionCtxs() { +func (c *Cli) scheduleUpdateTSOConnectionCtxs() { c.getDispatcher().scheduleUpdateConnectionCtxs() } -func (c *tsoClient) getTSORequest(ctx context.Context) *tsoRequest { - req := c.tsoReqPool.Get().(*tsoRequest) +// GetTSORequest gets a TSO request from the pool. +func (c *Cli) GetTSORequest(ctx context.Context) *Request { + req := c.tsoReqPool.Get().(*Request) // Set needed fields in the request before using it. req.start = time.Now() req.pool = c.tsoReqPool @@ -163,7 +171,7 @@ func (c *tsoClient) getTSORequest(ctx context.Context) *tsoRequest { return req } -func (c *tsoClient) getLeaderURL() string { +func (c *Cli) getLeaderURL() string { url := c.leaderURL.Load() if url == nil { return "" @@ -172,7 +180,7 @@ func (c *tsoClient) getLeaderURL() string { } // getTSOLeaderClientConn returns the TSO leader gRPC client connection. -func (c *tsoClient) getTSOLeaderClientConn() (*grpc.ClientConn, string) { +func (c *Cli) getTSOLeaderClientConn() (*grpc.ClientConn, string) { url := c.getLeaderURL() if len(url) == 0 { log.Fatal("[tso] the tso leader should exist") @@ -184,7 +192,7 @@ func (c *tsoClient) getTSOLeaderClientConn() (*grpc.ClientConn, string) { return cc.(*grpc.ClientConn), url } -func (c *tsoClient) updateTSOLeaderURL(url string) error { +func (c *Cli) updateTSOLeaderURL(url string) error { c.leaderURL.Store(url) log.Info("[tso] switch the tso leader serving url", zap.String("new-url", url)) // Try to create the TSO dispatcher if it is not created yet. @@ -197,7 +205,7 @@ func (c *tsoClient) updateTSOLeaderURL(url string) error { // backupClientConn gets a grpc client connection of the current reachable and healthy // backup service endpoints randomly. Backup service endpoints are followers in a // quorum-based cluster or secondaries in a primary/secondary configured cluster. -func (c *tsoClient) backupClientConn() (*grpc.ClientConn, string) { +func (c *Cli) backupClientConn() (*grpc.ClientConn, string) { urls := c.svcDiscovery.GetBackupURLs() if len(urls) < 1 { return nil, "" @@ -233,7 +241,7 @@ type tsoConnectionContext struct { // updateConnectionCtxs will choose the proper way to update the connections. // It will return a bool to indicate whether the update is successful. -func (c *tsoClient) updateConnectionCtxs(ctx context.Context, connectionCtxs *sync.Map) bool { +func (c *Cli) updateConnectionCtxs(ctx context.Context, connectionCtxs *sync.Map) bool { // Normal connection creating, it will be affected by the `enableForwarding`. createTSOConnection := c.tryConnectToTSO if c.option.GetEnableTSOFollowerProxy() { @@ -250,7 +258,7 @@ func (c *tsoClient) updateConnectionCtxs(ctx context.Context, connectionCtxs *sy // and enableForwarding is true, it will create a new connection to a follower to do the forwarding, // while a new daemon will be created also to switch back to a normal leader connection ASAP the // connection comes back to normal. -func (c *tsoClient) tryConnectToTSO( +func (c *Cli) tryConnectToTSO( ctx context.Context, connectionCtxs *sync.Map, ) error { @@ -276,10 +284,10 @@ func (c *tsoClient) tryConnectToTSO( } ) - ticker := time.NewTicker(retryInterval) + ticker := time.NewTicker(constants.RetryInterval) defer ticker.Stop() // Retry several times before falling back to the follower when the network problem happens - for range maxRetryTimes { + for range constants.MaxRetryTimes { c.svcDiscovery.ScheduleCheckMemberChanged() cc, url = c.getTSOLeaderClientConn() if _, ok := connectionCtxs.Load(url); ok { @@ -320,7 +328,7 @@ func (c *tsoClient) tryConnectToTSO( } } - if networkErrNum == maxRetryTimes { + if networkErrNum == constants.MaxRetryTimes { // encounter the network error backupClientConn, backupURL := c.backupClientConn() if backupClientConn != nil { @@ -349,7 +357,7 @@ func (c *tsoClient) tryConnectToTSO( return err } -func (c *tsoClient) checkLeader( +func (c *Cli) checkLeader( ctx context.Context, forwardCancel context.CancelFunc, forwardedHostTrim, addr, url string, @@ -403,7 +411,7 @@ func (c *tsoClient) checkLeader( // tryConnectToTSOWithProxy will create multiple streams to all the service endpoints to work as // a TSO proxy to reduce the pressure of the main serving service endpoint. -func (c *tsoClient) tryConnectToTSOWithProxy( +func (c *Cli) tryConnectToTSOWithProxy( ctx context.Context, connectionCtxs *sync.Map, ) error { @@ -458,7 +466,7 @@ func (c *tsoClient) tryConnectToTSOWithProxy( // getAllTSOStreamBuilders returns a TSO stream builder for every service endpoint of TSO leader/followers // or of keyspace group primary/secondaries. -func (c *tsoClient) getAllTSOStreamBuilders() map[string]tsoStreamBuilder { +func (c *Cli) getAllTSOStreamBuilders() map[string]tsoStreamBuilder { var ( addrs = c.svcDiscovery.GetServiceURLs() streamBuilders = make(map[string]tsoStreamBuilder, len(addrs)) @@ -483,7 +491,7 @@ func (c *tsoClient) getAllTSOStreamBuilders() map[string]tsoStreamBuilder { } // tryCreateTSODispatcher will try to create the TSO dispatcher if it is not created yet. -func (c *tsoClient) tryCreateTSODispatcher() { +func (c *Cli) tryCreateTSODispatcher() { // The dispatcher is already created. if c.getDispatcher() != nil { return @@ -502,8 +510,8 @@ func (c *tsoClient) tryCreateTSODispatcher() { } } -// dispatchRequest will send the TSO request to the corresponding TSO dispatcher. -func (c *tsoClient) dispatchRequest(request *tsoRequest) (bool, error) { +// DispatchRequest will send the TSO request to the corresponding TSO dispatcher. +func (c *Cli) DispatchRequest(request *Request) (bool, error) { if c.getDispatcher() == nil { err := errs.ErrClientGetTSO.FastGenByArgs("tso dispatcher is not ready") log.Error("[tso] dispatch tso request error", errs.ZapError(err)) diff --git a/client/tso_dispatcher.go b/client/clients/tso/tso_dispatcher.go similarity index 95% rename from client/tso_dispatcher.go rename to client/clients/tso/tso_dispatcher.go index fd9a17405e6..2ec59b3b6fc 100644 --- a/client/tso_dispatcher.go +++ b/client/clients/tso/tso_dispatcher.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pd +package tso import ( "context" @@ -28,6 +28,8 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/log" + "github.com/tikv/pd/client/batch" + "github.com/tikv/pd/client/constants" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" "github.com/tikv/pd/client/opt" @@ -84,7 +86,7 @@ type tsoDispatcher struct { provider tsoServiceProvider // URL -> *connectionContext connectionCtxs *sync.Map - tsoRequestCh chan *tsoRequest + tsoRequestCh chan *Request tsDeadlineCh chan *deadline latestTSOInfo atomic.Pointer[tsoInfo] // For reusing `*batchController` objects @@ -108,9 +110,9 @@ func newTSODispatcher( provider tsoServiceProvider, ) *tsoDispatcher { dispatcherCtx, dispatcherCancel := context.WithCancel(ctx) - tsoRequestCh := make(chan *tsoRequest, maxBatchSize*2) + tsoRequestCh := make(chan *Request, maxBatchSize*2) failpoint.Inject("shortDispatcherChannel", func() { - tsoRequestCh = make(chan *tsoRequest, 1) + tsoRequestCh = make(chan *Request, 1) }) // A large-enough capacity to hold maximum concurrent RPC requests. In our design, the concurrency is at most 16. @@ -126,7 +128,7 @@ func newTSODispatcher( tsDeadlineCh: make(chan *deadline, tokenChCapacity), batchBufferPool: &sync.Pool{ New: func() any { - return newBatchController[*tsoRequest]( + return batch.NewController[*Request]( maxBatchSize*2, tsoRequestFinisher(0, 0, invalidStreamID), metrics.TSOBestBatchSize, @@ -174,7 +176,7 @@ func (td *tsoDispatcher) scheduleUpdateConnectionCtxs() { func (td *tsoDispatcher) revokePendingRequests(err error) { for range len(td.tsoRequestCh) { req := <-td.tsoRequestCh - req.tryDone(err) + req.TryDone(err) } } @@ -184,7 +186,7 @@ func (td *tsoDispatcher) close() { td.revokePendingRequests(tsoErr) } -func (td *tsoDispatcher) push(request *tsoRequest) { +func (td *tsoDispatcher) push(request *Request) { td.tsoRequestCh <- request } @@ -195,7 +197,7 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { svcDiscovery = provider.getServiceDiscovery() option = provider.getOption() connectionCtxs = td.connectionCtxs - tsoBatchController *batchController[*tsoRequest] + tsoBatchController *batch.Controller[*Request] ) log.Info("[tso] tso dispatcher created") @@ -207,7 +209,7 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { cc.(*tsoConnectionContext).cancel() return true }) - if tsoBatchController != nil && tsoBatchController.collectedRequestCount != 0 { + if tsoBatchController != nil && tsoBatchController.GetCollectedRequestCount() != 0 { // If you encounter this failure, please check the stack in the logs to see if it's a panic. log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop", zap.Any("panic", recover())) } @@ -245,7 +247,7 @@ tsoBatchLoop: // In case error happens, the loop may continue without resetting `tsoBatchController` for retrying. if tsoBatchController == nil { - tsoBatchController = td.batchBufferPool.Get().(*batchController[*tsoRequest]) + tsoBatchController = td.batchBufferPool.Get().(*batch.Controller[*Request]) } maxBatchWaitInterval := option.GetMaxTSOBatchWaitInterval() @@ -262,7 +264,7 @@ tsoBatchLoop: // Start to collect the TSO requests. // Once the TSO requests are collected, must make sure they could be finished or revoked eventually, // otherwise the upper caller may get blocked on waiting for the results. - if err = tsoBatchController.fetchPendingRequests(ctx, td.tsoRequestCh, td.tokenCh, maxBatchWaitInterval); err != nil { + if err = tsoBatchController.FetchPendingRequests(ctx, td.tsoRequestCh, td.tokenCh, maxBatchWaitInterval); err != nil { if err == context.Canceled { log.Info("[tso] stop fetching the pending tso requests due to context canceled") } else { @@ -272,7 +274,7 @@ tsoBatchLoop: return } if maxBatchWaitInterval >= 0 { - tsoBatchController.adjustBestBatchSize() + tsoBatchController.AdjustBestBatchSize() } // Stop the timer if it's not stopped. if !streamLoopTimer.Stop() { @@ -297,7 +299,7 @@ tsoBatchLoop: if provider.updateConnectionCtxs(ctx, connectionCtxs) { continue streamChoosingLoop } - timer := time.NewTimer(retryInterval) + timer := time.NewTimer(constants.RetryInterval) select { case <-ctx.Done(): // Finish the collected requests if the context is canceled. @@ -381,7 +383,7 @@ tsoBatchLoop: } batchingTimer.Reset(remainingBatchTime) - err = tsoBatchController.fetchRequestsWithTimer(ctx, td.tsoRequestCh, batchingTimer) + err = tsoBatchController.FetchRequestsWithTimer(ctx, td.tsoRequestCh, batchingTimer) if err != nil { // There should not be other kinds of errors. log.Info("[tso] stop fetching the pending tso requests due to context canceled", @@ -531,11 +533,11 @@ func chooseStream(connectionCtxs *sync.Map) (connectionCtx *tsoConnectionContext // `close(done)` will be called at the same time when finishing the requests. // If this function returns a non-nil error, the requests will always be canceled synchronously. func (td *tsoDispatcher) processRequests( - stream *tsoStream, tbc *batchController[*tsoRequest], done chan struct{}, + stream *tsoStream, tbc *batch.Controller[*Request], done chan struct{}, ) error { // `done` must be guaranteed to be eventually called. var ( - requests = tbc.getCollectedRequests() + requests = tbc.GetCollectedRequests() traceRegions = make([]*trace.Region, 0, len(requests)) spans = make([]opentracing.Span, 0, len(requests)) ) @@ -594,7 +596,7 @@ func (td *tsoDispatcher) processRequests( err := stream.processRequests( clusterID, keyspaceID, reqKeyspaceGroupID, - count, tbc.extraBatchingStartTime, cb) + count, tbc.GetExtraBatchingStartTime(), cb) if err != nil { close(done) @@ -604,25 +606,25 @@ func (td *tsoDispatcher) processRequests( return nil } -func tsoRequestFinisher(physical, firstLogical int64, streamID string) finisherFunc[*tsoRequest] { - return func(idx int, tsoReq *tsoRequest, err error) { +func tsoRequestFinisher(physical, firstLogical int64, streamID string) batch.FinisherFunc[*Request] { + return func(idx int, tsoReq *Request, err error) { // Retrieve the request context before the request is done to trace without race. requestCtx := tsoReq.requestCtx tsoReq.physical, tsoReq.logical = physical, firstLogical+int64(idx) tsoReq.streamID = streamID - tsoReq.tryDone(err) + tsoReq.TryDone(err) trace.StartRegion(requestCtx, "pdclient.tsoReqDequeue").End() } } -func (td *tsoDispatcher) cancelCollectedRequests(tbc *batchController[*tsoRequest], streamID string, err error) { +func (td *tsoDispatcher) cancelCollectedRequests(tbc *batch.Controller[*Request], streamID string, err error) { td.tokenCh <- struct{}{} - tbc.finishCollectedRequests(tsoRequestFinisher(0, 0, streamID), err) + tbc.FinishCollectedRequests(tsoRequestFinisher(0, 0, streamID), err) } -func (td *tsoDispatcher) doneCollectedRequests(tbc *batchController[*tsoRequest], physical, firstLogical int64, streamID string) { +func (td *tsoDispatcher) doneCollectedRequests(tbc *batch.Controller[*Request], physical, firstLogical int64, streamID string) { td.tokenCh <- struct{}{} - tbc.finishCollectedRequests(tsoRequestFinisher(physical, firstLogical, streamID), nil) + tbc.FinishCollectedRequests(tsoRequestFinisher(physical, firstLogical, streamID), nil) } // checkMonotonicity checks whether the monotonicity of the TSO allocation is violated. diff --git a/client/tso_dispatcher_test.go b/client/clients/tso/tso_dispatcher_test.go similarity index 83% rename from client/tso_dispatcher_test.go rename to client/clients/tso/tso_dispatcher_test.go index 6cb963df3df..2b5fd1e52e8 100644 --- a/client/tso_dispatcher_test.go +++ b/client/clients/tso/tso_dispatcher_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pd +package tso import ( "context" @@ -107,7 +107,7 @@ func (s *testTSODispatcherSuite) SetupTest() { s.dispatcher = newTSODispatcher(context.Background(), defaultMaxTSOBatchSize, newMockTSOServiceProvider(s.option, createStream)) s.reqPool = &sync.Pool{ New: func() any { - return &tsoRequest{ + return &Request{ done: make(chan error, 1), physical: 0, logical: 0, @@ -146,8 +146,8 @@ func (s *testTSODispatcherSuite) TearDownTest() { s.reqPool = nil } -func (s *testTSODispatcherSuite) getReq(ctx context.Context) *tsoRequest { - req := s.reqPool.Get().(*tsoRequest) +func (s *testTSODispatcherSuite) getReq(ctx context.Context) *Request { + req := s.reqPool.Get().(*Request) req.clientCtx = context.Background() req.requestCtx = ctx req.physical = 0 @@ -157,19 +157,19 @@ func (s *testTSODispatcherSuite) getReq(ctx context.Context) *tsoRequest { return req } -func (s *testTSODispatcherSuite) sendReq(ctx context.Context) *tsoRequest { +func (s *testTSODispatcherSuite) sendReq(ctx context.Context) *Request { req := s.getReq(ctx) s.dispatcher.push(req) return req } -func (s *testTSODispatcherSuite) reqMustNotReady(req *tsoRequest) { +func (s *testTSODispatcherSuite) reqMustNotReady(req *Request) { _, _, err := req.waitTimeout(time.Millisecond * 50) s.re.Error(err) s.re.ErrorIs(err, context.DeadlineExceeded) } -func (s *testTSODispatcherSuite) reqMustReady(req *tsoRequest) (physical int64, logical int64) { +func (s *testTSODispatcherSuite) reqMustReady(req *Request) (physical int64, logical int64) { physical, logical, err := req.waitTimeout(time.Second) s.re.NoError(err) return physical, logical @@ -230,7 +230,7 @@ func (s *testTSODispatcherSuite) testStaticConcurrencyImpl(concurrency int) { // way. And as `reqMustNotReady` delays for a while, requests shouldn't be batched as long as there are free tokens. // The first N requests (N=tokenCount) will each be a single batch, occupying a token. The last 3 are blocked, // and will be batched together once there is a free token. - reqs := make([]*tsoRequest, 0, tokenCount+3) + reqs := make([]*Request, 0, tokenCount+3) for range tokenCount + 3 { req := s.sendReq(ctx) @@ -272,11 +272,11 @@ func (s *testTSODispatcherSuite) testStaticConcurrencyImpl(concurrency int) { } func (s *testTSODispatcherSuite) TestConcurrentRPC() { - s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/tsoDispatcherConcurrentModeNoDelay", "return")) - s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/tsoDispatcherAlwaysCheckConcurrency", "return")) + s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/clients/tso/tsoDispatcherConcurrentModeNoDelay", "return")) + s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/clients/tso/tsoDispatcherAlwaysCheckConcurrency", "return")) defer func() { - s.re.NoError(failpoint.Disable("github.com/tikv/pd/client/tsoDispatcherConcurrentModeNoDelay")) - s.re.NoError(failpoint.Disable("github.com/tikv/pd/client/tsoDispatcherAlwaysCheckConcurrency")) + s.re.NoError(failpoint.Disable("github.com/tikv/pd/client/clients/tso/tsoDispatcherConcurrentModeNoDelay")) + s.re.NoError(failpoint.Disable("github.com/tikv/pd/client/clients/tso/tsoDispatcherAlwaysCheckConcurrency")) }() s.testStaticConcurrencyImpl(1) @@ -289,11 +289,11 @@ func (s *testTSODispatcherSuite) TestBatchDelaying() { ctx := context.Background() s.option.SetTSOClientRPCConcurrency(2) - s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/tsoDispatcherConcurrentModeNoDelay", "return")) - s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/tsoStreamSimulateEstimatedRPCLatency", `return("12ms")`)) + s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/clients/tso/tsoDispatcherConcurrentModeNoDelay", "return")) + s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/clients/tso/tsoStreamSimulateEstimatedRPCLatency", `return("12ms")`)) defer func() { - s.re.NoError(failpoint.Disable("github.com/tikv/pd/client/tsoDispatcherConcurrentModeNoDelay")) - s.re.NoError(failpoint.Disable("github.com/tikv/pd/client/tsoStreamSimulateEstimatedRPCLatency")) + s.re.NoError(failpoint.Disable("github.com/tikv/pd/client/clients/tso/tsoDispatcherConcurrentModeNoDelay")) + s.re.NoError(failpoint.Disable("github.com/tikv/pd/client/clients/tso/tsoStreamSimulateEstimatedRPCLatency")) }() // Make sure concurrency option takes effect. @@ -302,9 +302,9 @@ func (s *testTSODispatcherSuite) TestBatchDelaying() { s.reqMustReady(req) // Trigger the check. - s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/tsoDispatcherConcurrentModeAssertDelayDuration", `return("6ms")`)) + s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/clients/tso/tsoDispatcherConcurrentModeAssertDelayDuration", `return("6ms")`)) defer func() { - s.re.NoError(failpoint.Disable("github.com/tikv/pd/client/tsoDispatcherConcurrentModeAssertDelayDuration")) + s.re.NoError(failpoint.Disable("github.com/tikv/pd/client/clients/tso/tsoDispatcherConcurrentModeAssertDelayDuration")) }() req = s.sendReq(ctx) s.streamInner.generateNext() @@ -312,13 +312,13 @@ func (s *testTSODispatcherSuite) TestBatchDelaying() { // Try other concurrency. s.option.SetTSOClientRPCConcurrency(3) - s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/tsoDispatcherConcurrentModeAssertDelayDuration", `return("4ms")`)) + s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/clients/tso/tsoDispatcherConcurrentModeAssertDelayDuration", `return("4ms")`)) req = s.sendReq(ctx) s.streamInner.generateNext() s.reqMustReady(req) s.option.SetTSOClientRPCConcurrency(4) - s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/tsoDispatcherConcurrentModeAssertDelayDuration", `return("3ms")`)) + s.re.NoError(failpoint.Enable("github.com/tikv/pd/client/clients/tso/tsoDispatcherConcurrentModeAssertDelayDuration", `return("3ms")`)) req = s.sendReq(ctx) s.streamInner.generateNext() s.reqMustReady(req) @@ -331,15 +331,15 @@ func BenchmarkTSODispatcherHandleRequests(b *testing.B) { reqPool := &sync.Pool{ New: func() any { - return &tsoRequest{ + return &Request{ done: make(chan error, 1), physical: 0, logical: 0, } }, } - getReq := func() *tsoRequest { - req := reqPool.Get().(*tsoRequest) + getReq := func() *Request { + req := reqPool.Get().(*Request) req.clientCtx = ctx req.requestCtx = ctx req.physical = 0 diff --git a/client/tso_request.go b/client/clients/tso/tso_request.go similarity index 78% rename from client/tso_request.go rename to client/clients/tso/tso_request.go index d2048e4b3b1..0c9f54f8b2b 100644 --- a/client/tso_request.go +++ b/client/clients/tso/tso_request.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pd +package tso import ( "context" @@ -31,11 +31,12 @@ type TSFuture interface { } var ( - _ TSFuture = (*tsoRequest)(nil) + _ TSFuture = (*Request)(nil) _ TSFuture = (*tsoRequestFastFail)(nil) ) -type tsoRequest struct { +// Request is a TSO request. +type Request struct { requestCtx context.Context clientCtx context.Context done chan error @@ -50,8 +51,16 @@ type tsoRequest struct { pool *sync.Pool } -// tryDone tries to send the result to the channel, it will not block. -func (req *tsoRequest) tryDone(err error) { +// IsFrom checks if the request is from the specified pool. +func (req *Request) IsFrom(pool *sync.Pool) bool { + if req == nil { + return false + } + return req.pool == pool +} + +// TryDone tries to send the result to the channel, it will not block. +func (req *Request) TryDone(err error) { select { case req.done <- err: default: @@ -59,12 +68,12 @@ func (req *tsoRequest) tryDone(err error) { } // Wait will block until the TSO result is ready. -func (req *tsoRequest) Wait() (physical int64, logical int64, err error) { +func (req *Request) Wait() (physical int64, logical int64, err error) { return req.waitCtx(req.requestCtx) } // waitCtx waits for the TSO result with specified ctx, while not using req.requestCtx. -func (req *tsoRequest) waitCtx(ctx context.Context) (physical int64, logical int64, err error) { +func (req *Request) waitCtx(ctx context.Context) (physical int64, logical int64, err error) { // If tso command duration is observed very high, the reason could be it // takes too long for Wait() be called. start := time.Now() @@ -92,7 +101,7 @@ func (req *tsoRequest) waitCtx(ctx context.Context) (physical int64, logical int } // waitTimeout waits for the TSO result for limited time. Currently only for test purposes. -func (req *tsoRequest) waitTimeout(timeout time.Duration) (physical int64, logical int64, err error) { +func (req *Request) waitTimeout(timeout time.Duration) (physical int64, logical int64, err error) { ctx, cancel := context.WithTimeout(req.requestCtx, timeout) defer cancel() return req.waitCtx(ctx) @@ -102,7 +111,8 @@ type tsoRequestFastFail struct { err error } -func newTSORequestFastFail(err error) *tsoRequestFastFail { +// NewRequestFastFail creates a new fast fail TSO request. +func NewRequestFastFail(err error) *tsoRequestFastFail { return &tsoRequestFastFail{err} } diff --git a/client/clients/tso/tso_request_test.go b/client/clients/tso/tso_request_test.go new file mode 100644 index 00000000000..6887ee28124 --- /dev/null +++ b/client/clients/tso/tso_request_test.go @@ -0,0 +1,50 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tso + +import ( + "context" + "testing" + + "github.com/pingcap/errors" + "github.com/stretchr/testify/require" +) + +func TestTsoRequestWait(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + req := &Request{ + done: make(chan error, 1), + physical: 0, + logical: 0, + requestCtx: context.TODO(), + clientCtx: ctx, + } + cancel() + _, _, err := req.Wait() + re.ErrorIs(errors.Cause(err), context.Canceled) + + ctx, cancel = context.WithCancel(context.Background()) + req = &Request{ + done: make(chan error, 1), + physical: 0, + logical: 0, + requestCtx: ctx, + clientCtx: context.TODO(), + } + cancel() + _, _, err = req.Wait() + re.ErrorIs(errors.Cause(err), context.Canceled) +} diff --git a/client/tso_stream.go b/client/clients/tso/tso_stream.go similarity index 95% rename from client/tso_stream.go rename to client/clients/tso/tso_stream.go index ce3c513ac46..6baf63c8882 100644 --- a/client/tso_stream.go +++ b/client/clients/tso/tso_stream.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pd +package tso import ( "context" @@ -42,16 +42,18 @@ type tsoStreamBuilderFactory interface { makeBuilder(cc *grpc.ClientConn) tsoStreamBuilder } -type pdTSOStreamBuilderFactory struct{} +// PDStreamBuilderFactory is a factory for building TSO streams to the PD cluster. +type PDStreamBuilderFactory struct{} -func (*pdTSOStreamBuilderFactory) makeBuilder(cc *grpc.ClientConn) tsoStreamBuilder { - return &pdTSOStreamBuilder{client: pdpb.NewPDClient(cc), serverURL: cc.Target()} +func (*PDStreamBuilderFactory) makeBuilder(cc *grpc.ClientConn) tsoStreamBuilder { + return &pdStreamBuilder{client: pdpb.NewPDClient(cc), serverURL: cc.Target()} } -type tsoTSOStreamBuilderFactory struct{} +// MSStreamBuilderFactory is a factory for building TSO streams to the microservice cluster. +type MSStreamBuilderFactory struct{} -func (*tsoTSOStreamBuilderFactory) makeBuilder(cc *grpc.ClientConn) tsoStreamBuilder { - return &tsoTSOStreamBuilder{client: tsopb.NewTSOClient(cc), serverURL: cc.Target()} +func (*MSStreamBuilderFactory) makeBuilder(cc *grpc.ClientConn) tsoStreamBuilder { + return &msStreamBuilder{client: tsopb.NewTSOClient(cc), serverURL: cc.Target()} } // TSO Stream Builder @@ -60,12 +62,12 @@ type tsoStreamBuilder interface { build(context.Context, context.CancelFunc, time.Duration) (*tsoStream, error) } -type pdTSOStreamBuilder struct { +type pdStreamBuilder struct { serverURL string client pdpb.PDClient } -func (b *pdTSOStreamBuilder) build(ctx context.Context, cancel context.CancelFunc, timeout time.Duration) (*tsoStream, error) { +func (b *pdStreamBuilder) build(ctx context.Context, cancel context.CancelFunc, timeout time.Duration) (*tsoStream, error) { done := make(chan struct{}) // TODO: we need to handle a conner case that this goroutine is timeout while the stream is successfully created. go checkStreamTimeout(ctx, cancel, done, timeout) @@ -77,12 +79,12 @@ func (b *pdTSOStreamBuilder) build(ctx context.Context, cancel context.CancelFun return nil, err } -type tsoTSOStreamBuilder struct { +type msStreamBuilder struct { serverURL string client tsopb.TSOClient } -func (b *tsoTSOStreamBuilder) build( +func (b *msStreamBuilder) build( ctx context.Context, cancel context.CancelFunc, timeout time.Duration, ) (*tsoStream, error) { done := make(chan struct{}) diff --git a/client/tso_stream_test.go b/client/clients/tso/tso_stream_test.go similarity index 99% rename from client/tso_stream_test.go rename to client/clients/tso/tso_stream_test.go index a842befb550..0244c06e024 100644 --- a/client/tso_stream_test.go +++ b/client/clients/tso/tso_stream_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pd +package tso import ( "context" diff --git a/client/constants/constants.go b/client/constants/constants.go index 10963dd10b6..7c8a5a751ae 100644 --- a/client/constants/constants.go +++ b/client/constants/constants.go @@ -14,6 +14,8 @@ package constants +import "time" + const ( // DefaultKeyspaceID is the default keyspace ID. // Valid keyspace id range is [0, 0xFFFFFF](uint24max, or 16777215) @@ -29,4 +31,9 @@ const ( DefaultKeyspaceGroupID = uint32(0) // DefaultKeyspaceName is the default keyspace name. DefaultKeyspaceName = "DEFAULT" + + // RetryInterval is the base retry interval. + RetryInterval = 500 * time.Millisecond + // MaxRetryTimes is the max retry times. + MaxRetryTimes = 6 ) diff --git a/client/inner_client.go b/client/inner_client.go index 467d6b66352..7be35e9a3b9 100644 --- a/client/inner_client.go +++ b/client/inner_client.go @@ -9,6 +9,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" + "github.com/tikv/pd/client/clients/tso" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" "github.com/tikv/pd/client/opt" @@ -83,21 +84,21 @@ func (c *innerClient) setServiceMode(newMode pdpb.ServiceMode) { func (c *innerClient) resetTSOClientLocked(mode pdpb.ServiceMode) { // Re-create a new TSO client. var ( - newTSOCli *tsoClient + newTSOCli *tso.Cli newTSOSvcDiscovery sd.ServiceDiscovery ) switch mode { case pdpb.ServiceMode_PD_SVC_MODE: - newTSOCli = newTSOClient(c.ctx, c.option, - c.pdSvcDiscovery, &pdTSOStreamBuilderFactory{}) + newTSOCli = tso.NewClient(c.ctx, c.option, + c.pdSvcDiscovery, &tso.PDStreamBuilderFactory{}) case pdpb.ServiceMode_API_SVC_MODE: newTSOSvcDiscovery = sd.NewTSOServiceDiscovery( c.ctx, c, c.pdSvcDiscovery, c.keyspaceID, c.tlsCfg, c.option) // At this point, the keyspace group isn't known yet. Starts from the default keyspace group, // and will be updated later. - newTSOCli = newTSOClient(c.ctx, c.option, - newTSOSvcDiscovery, &tsoTSOStreamBuilderFactory{}) + newTSOCli = tso.NewClient(c.ctx, c.option, + newTSOSvcDiscovery, &tso.MSStreamBuilderFactory{}) if err := newTSOSvcDiscovery.Init(); err != nil { log.Error("[pd] failed to initialize tso service discovery. keep the current service mode", zap.Strings("svr-urls", c.svrUrls), @@ -109,11 +110,11 @@ func (c *innerClient) resetTSOClientLocked(mode pdpb.ServiceMode) { log.Warn("[pd] intend to switch to unknown service mode, just return") return } - newTSOCli.setup() + newTSOCli.Setup() // Replace the old TSO client. oldTSOClient := c.tsoClient c.tsoClient = newTSOCli - oldTSOClient.close() + oldTSOClient.Close() // Replace the old TSO service discovery if needed. oldTSOSvcDiscovery := c.tsoSvcDiscovery // If newTSOSvcDiscovery is nil, that's expected, as it means we are switching to PD service mode and @@ -139,7 +140,7 @@ func (c *innerClient) getServiceMode() pdpb.ServiceMode { return c.serviceMode } -func (c *innerClient) getTSOClient() *tsoClient { +func (c *innerClient) getTSOClient() *tso.Cli { c.RLock() defer c.RUnlock() return c.tsoClient @@ -210,11 +211,11 @@ func (c *innerClient) getOrCreateGRPCConn() (*grpc.ClientConn, error) { return cc, err } -func (c *innerClient) dispatchTSORequestWithRetry(ctx context.Context) TSFuture { +func (c *innerClient) dispatchTSORequestWithRetry(ctx context.Context) tso.TSFuture { var ( retryable bool err error - req *tsoRequest + req *tso.Request ) for i := range dispatchRetryCount { // Do not delay for the first time. @@ -227,20 +228,20 @@ func (c *innerClient) dispatchTSORequestWithRetry(ctx context.Context) TSFuture err = errs.ErrClientGetTSO.FastGenByArgs("tso client is nil") continue } - // Get a new request from the pool if it's nil or not from the current pool. - if req == nil || req.pool != tsoClient.tsoReqPool { - req = tsoClient.getTSORequest(ctx) + // Get a new request from the pool if it's not from the current pool. + if !req.IsFrom(tsoClient.GetRequestPool()) { + req = tsoClient.GetTSORequest(ctx) } - retryable, err = tsoClient.dispatchRequest(req) + retryable, err = tsoClient.DispatchRequest(req) if !retryable { break } } if err != nil { if req == nil { - return newTSORequestFastFail(err) + return tso.NewRequestFastFail(err) } - req.tryDone(err) + req.TryDone(err) } return req } diff --git a/client/resource_manager_client.go b/client/resource_manager_client.go index 513f1a1d170..3cf2970109f 100644 --- a/client/resource_manager_client.go +++ b/client/resource_manager_client.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/kvproto/pkg/meta_storagepb" rmpb "github.com/pingcap/kvproto/pkg/resource_manager" "github.com/pingcap/log" + "github.com/tikv/pd/client/constants" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/opt" "go.uber.org/zap" @@ -374,9 +375,9 @@ func (c *innerClient) tryResourceManagerConnect(ctx context.Context, connection err error stream rmpb.ResourceManager_AcquireTokenBucketsClient ) - ticker := time.NewTicker(retryInterval) + ticker := time.NewTicker(constants.RetryInterval) defer ticker.Stop() - for range maxRetryTimes { + for range constants.MaxRetryTimes { cc, err := c.resourceManagerClient() if err != nil { continue diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go index 79f981f3bb3..4a86a2402a1 100644 --- a/tests/integrations/client/client_test.go +++ b/tests/integrations/client/client_test.go @@ -404,10 +404,10 @@ func TestUnavailableTimeAfterLeaderIsReady(t *testing.T) { go func() { defer wg.Done() leader := cluster.GetLeaderServer() - re.NoError(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/clients/tso/unreachableNetwork", "return(true)")) leader.Stop() re.NotEmpty(cluster.WaitLeader()) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/clients/tso/unreachableNetwork")) leaderReadyTime = time.Now() }() wg.Wait() @@ -519,7 +519,7 @@ func (suite *followerForwardAndHandleTestSuite) TestGetTsoByFollowerForwarding1( cli := setupCli(ctx, re, suite.endpoints, opt.WithForwardingOption(true)) defer cli.Close() - re.NoError(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/clients/tso/unreachableNetwork", "return(true)")) var lastTS uint64 testutil.Eventually(re, func() bool { physical, logical, err := cli.GetTS(context.TODO()) @@ -532,7 +532,7 @@ func (suite *followerForwardAndHandleTestSuite) TestGetTsoByFollowerForwarding1( }) lastTS = checkTS(re, cli, lastTS) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/clients/tso/unreachableNetwork")) time.Sleep(2 * time.Second) checkTS(re, cli, lastTS) @@ -554,7 +554,7 @@ func (suite *followerForwardAndHandleTestSuite) TestGetTsoByFollowerForwarding2( cli := setupCli(ctx, re, suite.endpoints, opt.WithForwardingOption(true)) defer cli.Close() - re.NoError(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/clients/tso/unreachableNetwork", "return(true)")) var lastTS uint64 testutil.Eventually(re, func() bool { physical, logical, err := cli.GetTS(context.TODO()) @@ -571,7 +571,7 @@ func (suite *followerForwardAndHandleTestSuite) TestGetTsoByFollowerForwarding2( re.NotEmpty(suite.cluster.WaitLeader()) lastTS = checkTS(re, cli, lastTS) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/clients/tso/unreachableNetwork")) time.Sleep(5 * time.Second) checkTS(re, cli, lastTS) } @@ -783,7 +783,7 @@ func (suite *followerForwardAndHandleTestSuite) TestGetTSFuture() { ctx, cancel := context.WithCancel(suite.ctx) defer cancel() - re.NoError(failpoint.Enable("github.com/tikv/pd/client/shortDispatcherChannel", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/clients/tso/shortDispatcherChannel", "return(true)")) cli := setupCli(ctx, re, suite.endpoints) @@ -820,7 +820,7 @@ func (suite *followerForwardAndHandleTestSuite) TestGetTSFuture() { wg2.Wait() wg3.Wait() re.Less(time.Since(start), time.Second*2) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/shortDispatcherChannel")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/clients/tso/shortDispatcherChannel")) } func checkTS(re *require.Assertions, cli pd.Client, lastTS uint64) uint64 { diff --git a/tests/integrations/tso/client_test.go b/tests/integrations/tso/client_test.go index 422d578326a..da2f6d9f5c9 100644 --- a/tests/integrations/tso/client_test.go +++ b/tests/integrations/tso/client_test.go @@ -30,6 +30,7 @@ import ( "github.com/stretchr/testify/suite" pd "github.com/tikv/pd/client" "github.com/tikv/pd/client/caller" + "github.com/tikv/pd/client/clients/tso" "github.com/tikv/pd/client/opt" sd "github.com/tikv/pd/client/servicediscovery" "github.com/tikv/pd/client/utils/testutil" @@ -241,7 +242,7 @@ func (suite *tsoClientTestSuite) TestGetTSAsync() { for _, client := range suite.clients { go func(client pd.Client) { defer wg.Done() - tsFutures := make([]pd.TSFuture, tsoRequestRound) + tsFutures := make([]tso.TSFuture, tsoRequestRound) for j := range tsFutures { tsFutures[j] = client.GetTSAsync(suite.ctx) } @@ -447,7 +448,7 @@ func (suite *tsoClientTestSuite) TestRandomShutdown() { func (suite *tsoClientTestSuite) TestGetTSWhileResettingTSOClient() { re := suite.Require() - re.NoError(failpoint.Enable("github.com/tikv/pd/client/delayDispatchTSORequest", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/clients/tso/delayDispatchTSORequest", "return(true)")) var ( stopSignal atomic.Bool wg sync.WaitGroup @@ -480,7 +481,7 @@ func (suite *tsoClientTestSuite) TestGetTSWhileResettingTSOClient() { } stopSignal.Store(true) wg.Wait() - re.NoError(failpoint.Disable("github.com/tikv/pd/client/delayDispatchTSORequest")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/clients/tso/delayDispatchTSORequest")) } // When we upgrade the PD cluster, there may be a period of time that the old and new PDs are running at the same time.