diff --git a/client/client.go b/client/client.go index eaebef7e10c1..1865fd0866ed 100644 --- a/client/client.go +++ b/client/client.go @@ -301,7 +301,7 @@ func (k *serviceModeKeeper) close() { fallthrough case pdpb.ServiceMode_PD_SVC_MODE: if k.tsoClient != nil { - k.tsoClient.Close() + k.tsoClient.close() } case pdpb.ServiceMode_UNKNOWN_SVC_MODE: } @@ -651,11 +651,11 @@ func (c *client) resetTSOClientLocked(mode pdpb.ServiceMode) { log.Warn("[pd] intend to switch to unknown service mode, just return") return } - newTSOCli.Setup() + newTSOCli.setup() // Replace the old TSO client. oldTSOClient := c.tsoClient c.tsoClient = newTSOCli - oldTSOClient.Close() + oldTSOClient.close() // Replace the old TSO service discovery if needed. oldTSOSvcDiscovery := c.tsoSvcDiscovery // If newTSOSvcDiscovery is nil, that's expected, as it means we are switching to PD service mode and diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go index d7ba5d7e74bb..a713b7a187d8 100644 --- a/client/tso_batch_controller.go +++ b/client/tso_batch_controller.go @@ -139,9 +139,11 @@ func (tbc *tsoBatchController) adjustBestBatchSize() { func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical int64, suffixBits uint32, err error) { for i := 0; i < tbc.collectedRequestCount; i++ { tsoReq := tbc.collectedRequests[i] + // Retrieve the request context before the request is done to trace without race. + requestCtx := tsoReq.requestCtx tsoReq.physical, tsoReq.logical = physical, tsoutil.AddLogical(firstLogical, int64(i), suffixBits) - defer trace.StartRegion(tsoReq.requestCtx, "pdclient.tsoReqDequeue").End() // nolint tsoReq.tryDone(err) + trace.StartRegion(requestCtx, "pdclient.tsoReqDequeue").End() } // Prevent the finished requests from being processed again. tbc.collectedRequestCount = 0 diff --git a/client/tso_client.go b/client/tso_client.go index 347d1f6ec0ad..72b09d8054df 100644 --- a/client/tso_client.go +++ b/client/tso_client.go @@ -22,13 +22,11 @@ import ( "sync" "time" - "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/log" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/grpcutil" - "github.com/tikv/pd/client/tsoutil" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -36,6 +34,15 @@ import ( "google.golang.org/grpc/status" ) +const ( + tsoDispatcherCheckInterval = time.Minute + // defaultMaxTSOBatchSize is the default max size of the TSO request batch. + defaultMaxTSOBatchSize = 10000 + // retryInterval and maxRetryTimes are used to control the retry interval and max retry times. + retryInterval = 500 * time.Millisecond + maxRetryTimes = 6 +) + // TSOClient is the client used to get timestamps. type TSOClient interface { // GetTS gets a timestamp from PD or TSO microservice. @@ -70,14 +77,8 @@ type tsoClient struct { // tsoDispatcher is used to dispatch different TSO requests to // the corresponding dc-location TSO channel. tsoDispatcher sync.Map // Same as map[string]*tsoDispatcher - // dc-location -> deadline - tsDeadline sync.Map // Same as map[string]chan deadline - // dc-location -> *tsoInfo while the tsoInfo is the last TSO info - lastTSOInfoMap sync.Map // Same as map[string]*tsoInfo - - checkTSDeadlineCh chan struct{} - checkTSODispatcherCh chan struct{} - updateTSOConnectionCtxsCh chan struct{} + + checkTSODispatcherCh chan struct{} } // newTSOClient returns a new TSO client. 
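The new constants above (retryInterval, maxRetryTimes) drive the ticker-based retry loop that tryConnectToTSO uses later in this patch: attempt to connect, wait retryInterval between attempts, give up after maxRetryTimes tries, and exit early if the context is canceled. A minimal, self-contained sketch of that pattern, with a hypothetical connect callback standing in for the real stream builder:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

const (
	retryInterval = 500 * time.Millisecond
	maxRetryTimes = 6
)

// retryWithTicker retries connect up to maxRetryTimes times, pausing
// retryInterval between attempts and aborting early on context cancellation.
func retryWithTicker(ctx context.Context, connect func() error) error {
	ticker := time.NewTicker(retryInterval)
	defer ticker.Stop()
	var err error
	for i := 0; i < maxRetryTimes; i++ {
		if err = connect(); err == nil {
			return nil
		}
		select {
		case <-ctx.Done():
			return err
		case <-ticker.C:
		}
	}
	return err
}

func main() {
	err := retryWithTicker(context.Background(), func() error { return errors.New("unreachable") })
	fmt.Println(err) // after roughly maxRetryTimes*retryInterval: unreachable
}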
@@ -101,49 +102,64 @@ func newTSOClient( } }, }, - checkTSDeadlineCh: make(chan struct{}), - checkTSODispatcherCh: make(chan struct{}, 1), - updateTSOConnectionCtxsCh: make(chan struct{}, 1), + checkTSODispatcherCh: make(chan struct{}, 1), } eventSrc := svcDiscovery.(tsoAllocatorEventSource) eventSrc.SetTSOLocalServURLsUpdatedCallback(c.updateTSOLocalServURLs) eventSrc.SetTSOGlobalServURLUpdatedCallback(c.updateTSOGlobalServURL) - c.svcDiscovery.AddServiceURLsSwitchedCallback(c.scheduleUpdateTSOConnectionCtxs) + c.svcDiscovery.AddServiceURLsSwitchedCallback(c.scheduleUpdateAllTSOConnectionCtxs) return c } -func (c *tsoClient) Setup() { +func (c *tsoClient) getOption() *option { return c.option } + +func (c *tsoClient) getServiceDiscovery() ServiceDiscovery { return c.svcDiscovery } + +func (c *tsoClient) setup() { c.svcDiscovery.CheckMemberChanged() c.updateTSODispatcher() // Start the daemons. - c.wg.Add(2) + c.wg.Add(1) go c.tsoDispatcherCheckLoop() - go c.tsCancelLoop() } -// Close closes the TSO client -func (c *tsoClient) Close() { +func (c *tsoClient) tsoDispatcherCheckLoop() { + log.Info("[tso] start tso dispatcher check loop") + defer log.Info("[tso] exit tso dispatcher check loop") + defer c.wg.Done() + + loopCtx, loopCancel := context.WithCancel(c.ctx) + defer loopCancel() + + ticker := time.NewTicker(tsoDispatcherCheckInterval) + defer ticker.Stop() + for { + c.updateTSODispatcher() + select { + case <-ticker.C: + case <-c.checkTSODispatcherCh: + case <-loopCtx.Done(): + return + } + } +} + +// close closes the TSO client +func (c *tsoClient) close() { if c == nil { return } - log.Info("closing tso client") + log.Info("[tso] closing tso client") c.cancel() c.wg.Wait() - log.Info("close tso client") + log.Info("[tso] close tso client") c.closeTSODispatcher() - log.Info("tso client is closed") -} - -func (c *tsoClient) scheduleCheckTSDeadline() { - select { - case c.checkTSDeadlineCh <- struct{}{}: - default: - } + log.Info("[tso] tso client is closed") } func (c *tsoClient) scheduleCheckTSODispatcher() { @@ -153,11 +169,21 @@ func (c *tsoClient) scheduleCheckTSODispatcher() { } } -func (c *tsoClient) scheduleUpdateTSOConnectionCtxs() { - select { - case c.updateTSOConnectionCtxsCh <- struct{}{}: - default: +// scheduleUpdateAllTSOConnectionCtxs update the TSO connection contexts for all dc-locations. +func (c *tsoClient) scheduleUpdateAllTSOConnectionCtxs() { + c.tsoDispatcher.Range(func(_, dispatcher any) bool { + dispatcher.(*tsoDispatcher).scheduleUpdateConnectionCtxs() + return true + }) +} + +// scheduleUpdateTSOConnectionCtxs update the TSO connection contexts for the given dc-location. +func (c *tsoClient) scheduleUpdateTSOConnectionCtxs(dcLocation string) { + dispatcher, ok := c.getTSODispatcher(dcLocation) + if !ok { + return } + dispatcher.scheduleUpdateConnectionCtxs() } // TSO Follower Proxy only supports the Global TSO proxy now. 
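Both scheduleCheckTSODispatcher above and the per-dispatcher scheduleUpdateConnectionCtxs it now delegates to rely on the same non-blocking notification idiom: a channel with capacity 1 plus a select/default, so repeated wake-up requests coalesce while one is already pending and callers never block. A small, self-contained sketch of that idiom (the names here are illustrative, not PD client APIs):

package main

import "fmt"

// notifier coalesces wake-up signals: at most one is ever pending.
type notifier struct {
	ch chan struct{}
}

func newNotifier() *notifier {
	return &notifier{ch: make(chan struct{}, 1)}
}

// schedule never blocks; if a signal is already pending, the new one is dropped.
func (n *notifier) schedule() {
	select {
	case n.ch <- struct{}{}:
	default:
	}
}

func main() {
	n := newNotifier()
	n.schedule()
	n.schedule() // coalesced with the first signal
	n.schedule() // also coalesced
	fmt.Println(len(n.ch)) // 1: the consumer will wake up exactly once
}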
@@ -178,6 +204,14 @@ func (c *tsoClient) getTSORequest(ctx context.Context, dcLocation string) *tsoRe return req } +func (c *tsoClient) getTSODispatcher(dcLocation string) (*tsoDispatcher, bool) { + dispatcher, ok := c.tsoDispatcher.Load(dcLocation) + if !ok || dispatcher == nil { + return nil, false + } + return dispatcher.(*tsoDispatcher), true +} + // GetTSOAllocators returns {dc-location -> TSO allocator leader URL} connection map func (c *tsoClient) GetTSOAllocators() *sync.Map { return &c.tsoAllocators @@ -192,14 +226,12 @@ func (c *tsoClient) GetTSOAllocatorServingURLByDCLocation(dcLocation string) (st return url.(string), true } -// GetTSOAllocatorClientConnByDCLocation returns the tso allocator grpc client connection -// of the given dcLocation +// GetTSOAllocatorClientConnByDCLocation returns the TSO allocator gRPC client connection of the given dcLocation. func (c *tsoClient) GetTSOAllocatorClientConnByDCLocation(dcLocation string) (*grpc.ClientConn, string) { url, ok := c.tsoAllocators.Load(dcLocation) if !ok { - panic(fmt.Sprintf("the allocator leader in %s should exist", dcLocation)) + log.Fatal("[tso] the allocator leader should exist", zap.String("dc-location", dcLocation)) } - // todo: if we support local tso forward, we should get or create client conns. cc, ok := c.svcDiscovery.GetClientConns().Load(url) if !ok { return nil, url.(string) @@ -242,6 +274,8 @@ func (c *tsoClient) updateTSOLocalServURLs(allocatorMap map[string]string) error zap.String("dc-location", dcLocation), zap.String("new-url", url), zap.String("old-url", oldURL)) + // Should trigger the update of the connection contexts once the allocator leader is switched. + c.scheduleUpdateTSOConnectionCtxs(dcLocation) } // Garbage collection of the old TSO allocator primaries @@ -259,6 +293,7 @@ func (c *tsoClient) updateTSOGlobalServURL(url string) error { log.Info("[tso] switch dc tso global allocator serving url", zap.String("dc-location", globalDCLocation), zap.String("new-url", url)) + c.scheduleUpdateTSOConnectionCtxs(globalDCLocation) c.scheduleCheckTSODispatcher() return nil } @@ -306,22 +341,27 @@ func (c *tsoClient) backupClientConn() (*grpc.ClientConn, string) { return nil, "" } +// tsoConnectionContext is used to store the context of a TSO stream connection. type tsoConnectionContext struct { - streamURL string - // Current stream to send gRPC requests, pdpb.PD_TsoClient for a leader/follower in the PD cluster, - // or tsopb.TSO_TsoClient for a primary/secondary in the TSO cluster - stream tsoStream ctx context.Context cancel context.CancelFunc + // Current URL of the stream connection. + streamURL string + // Current stream to send gRPC requests. + // - `pdpb.PD_TsoClient` for a leader/follower in the PD cluster. + // - `tsopb.TSO_TsoClient` for a primary/secondary in the TSO cluster. + stream tsoStream } -func (c *tsoClient) updateTSOConnectionCtxs(updaterCtx context.Context, dc string, connectionCtxs *sync.Map) bool { +// updateConnectionCtxs will choose the proper way to update the connections for the given dc-location. +// It will return a bool to indicate whether the update is successful. +func (c *tsoClient) updateConnectionCtxs(ctx context.Context, dc string, connectionCtxs *sync.Map) bool { // Normal connection creating, it will be affected by the `enableForwarding`. 
createTSOConnection := c.tryConnectToTSO if c.allowTSOFollowerProxy(dc) { createTSOConnection = c.tryConnectToTSOWithProxy } - if err := createTSOConnection(updaterCtx, dc, connectionCtxs); err != nil { + if err := createTSOConnection(ctx, dc, connectionCtxs); err != nil { log.Error("[tso] update connection contexts failed", zap.String("dc", dc), errs.ZapError(err)) return false } @@ -333,47 +373,48 @@ func (c *tsoClient) updateTSOConnectionCtxs(updaterCtx context.Context, dc strin // while a new daemon will be created also to switch back to a normal leader connection ASAP the // connection comes back to normal. func (c *tsoClient) tryConnectToTSO( - dispatcherCtx context.Context, + ctx context.Context, dc string, connectionCtxs *sync.Map, ) error { var ( - networkErrNum uint64 - err error - stream tsoStream - url string - cc *grpc.ClientConn - ) - updateAndClear := func(newURL string, connectionCtx *tsoConnectionContext) { - if cc, loaded := connectionCtxs.LoadOrStore(newURL, connectionCtx); loaded { - // If the previous connection still exists, we should close it first. - cc.(*tsoConnectionContext).cancel() - connectionCtxs.Store(newURL, connectionCtx) + networkErrNum uint64 + err error + stream tsoStream + url string + cc *grpc.ClientConn + updateAndClear = func(newURL string, connectionCtx *tsoConnectionContext) { + // Only store the `connectionCtx` if it does not exist before. + connectionCtxs.LoadOrStore(newURL, connectionCtx) + // Remove all other `connectionCtx`s. + connectionCtxs.Range(func(url, cc any) bool { + if url.(string) != newURL { + cc.(*tsoConnectionContext).cancel() + connectionCtxs.Delete(url) + } + return true + }) } - connectionCtxs.Range(func(url, cc any) bool { - if url.(string) != newURL { - cc.(*tsoConnectionContext).cancel() - connectionCtxs.Delete(url) - } - return true - }) - } - // retry several times before falling back to the follower when the network problem happens + ) ticker := time.NewTicker(retryInterval) defer ticker.Stop() + // Retry several times before falling back to the follower when the network problem happens for i := 0; i < maxRetryTimes; i++ { c.svcDiscovery.ScheduleCheckMemberChanged() cc, url = c.GetTSOAllocatorClientConnByDCLocation(dc) + if _, ok := connectionCtxs.Load(url); ok { + return nil + } if cc != nil { - cctx, cancel := context.WithCancel(dispatcherCtx) + cctx, cancel := context.WithCancel(ctx) stream, err = c.tsoStreamBuilderFactory.makeBuilder(cc).build(cctx, cancel, c.option.timeout) failpoint.Inject("unreachableNetwork", func() { stream = nil err = status.New(codes.Unavailable, "unavailable").Err() }) if stream != nil && err == nil { - updateAndClear(url, &tsoConnectionContext{url, stream, cctx, cancel}) + updateAndClear(url, &tsoConnectionContext{cctx, cancel, url, stream}) return nil } @@ -392,7 +433,7 @@ func (c *tsoClient) tryConnectToTSO( networkErrNum++ } select { - case <-dispatcherCtx.Done(): + case <-ctx.Done(): return err case <-ticker.C: } @@ -409,16 +450,16 @@ func (c *tsoClient) tryConnectToTSO( } // create the follower stream - cctx, cancel := context.WithCancel(dispatcherCtx) + cctx, cancel := context.WithCancel(ctx) cctx = grpcutil.BuildForwardContext(cctx, forwardedHost) stream, err = c.tsoStreamBuilderFactory.makeBuilder(backupClientConn).build(cctx, cancel, c.option.timeout) if err == nil { forwardedHostTrim := trimHTTPPrefix(forwardedHost) addr := trimHTTPPrefix(backupURL) // the goroutine is used to check the network and change back to the original stream - go c.checkAllocator(dispatcherCtx, cancel, dc, 
forwardedHostTrim, addr, url, updateAndClear) + go c.checkAllocator(ctx, cancel, dc, forwardedHostTrim, addr, url, updateAndClear) requestForwarded.WithLabelValues(forwardedHostTrim, addr).Set(1) - updateAndClear(backupURL, &tsoConnectionContext{backupURL, stream, cctx, cancel}) + updateAndClear(backupURL, &tsoConnectionContext{cctx, cancel, backupURL, stream}) return nil } cancel() @@ -427,9 +468,66 @@ func (c *tsoClient) tryConnectToTSO( return err } +func (c *tsoClient) checkAllocator( + ctx context.Context, + forwardCancel context.CancelFunc, + dc, forwardedHostTrim, addr, url string, + updateAndClear func(newAddr string, connectionCtx *tsoConnectionContext), +) { + defer func() { + // cancel the forward stream + forwardCancel() + requestForwarded.WithLabelValues(forwardedHostTrim, addr).Set(0) + }() + cc, u := c.GetTSOAllocatorClientConnByDCLocation(dc) + var healthCli healthpb.HealthClient + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for { + // the pd/allocator leader change, we need to re-establish the stream + if u != url { + log.Info("[tso] the leader of the allocator leader is changed", zap.String("dc", dc), zap.String("origin", url), zap.String("new", u)) + return + } + if healthCli == nil && cc != nil { + healthCli = healthpb.NewHealthClient(cc) + } + if healthCli != nil { + healthCtx, healthCancel := context.WithTimeout(ctx, c.option.timeout) + resp, err := healthCli.Check(healthCtx, &healthpb.HealthCheckRequest{Service: ""}) + failpoint.Inject("unreachableNetwork", func() { + resp.Status = healthpb.HealthCheckResponse_UNKNOWN + }) + healthCancel() + if err == nil && resp.GetStatus() == healthpb.HealthCheckResponse_SERVING { + // create a stream of the original allocator + cctx, cancel := context.WithCancel(ctx) + stream, err := c.tsoStreamBuilderFactory.makeBuilder(cc).build(cctx, cancel, c.option.timeout) + if err == nil && stream != nil { + log.Info("[tso] recover the original tso stream since the network has become normal", zap.String("dc", dc), zap.String("url", url)) + updateAndClear(url, &tsoConnectionContext{cctx, cancel, url, stream}) + return + } + } + } + select { + case <-ctx.Done(): + return + case <-ticker.C: + // To ensure we can get the latest allocator leader + // and once the leader is changed, we can exit this function. + cc, u = c.GetTSOAllocatorClientConnByDCLocation(dc) + } + } +} + // tryConnectToTSOWithProxy will create multiple streams to all the service endpoints to work as // a TSO proxy to reduce the pressure of the main serving service endpoint. -func (c *tsoClient) tryConnectToTSOWithProxy(dispatcherCtx context.Context, dc string, connectionCtxs *sync.Map) error { +func (c *tsoClient) tryConnectToTSOWithProxy( + ctx context.Context, + dc string, + connectionCtxs *sync.Map, +) error { tsoStreamBuilders := c.getAllTSOStreamBuilders() leaderAddr := c.svcDiscovery.GetServingURL() forwardedHost, ok := c.GetTSOAllocatorServingURLByDCLocation(dc) @@ -455,7 +553,7 @@ func (c *tsoClient) tryConnectToTSOWithProxy(dispatcherCtx context.Context, dc s } log.Info("[tso] try to create tso stream", zap.String("dc", dc), zap.String("addr", addr)) - cctx, cancel := context.WithCancel(dispatcherCtx) + cctx, cancel := context.WithCancel(ctx) // Do not proxy the leader client. 
if addr != leaderAddr { log.Info("[tso] use follower to forward tso stream to do the proxy", @@ -470,7 +568,7 @@ func (c *tsoClient) tryConnectToTSOWithProxy(dispatcherCtx context.Context, dc s addrTrim := trimHTTPPrefix(addr) requestForwarded.WithLabelValues(forwardedHostTrim, addrTrim).Set(1) } - connectionCtxs.Store(addr, &tsoConnectionContext{addr, stream, cctx, cancel}) + connectionCtxs.Store(addr, &tsoConnectionContext{cctx, cancel, addr, stream}) continue } log.Error("[tso] create the tso stream failed", @@ -506,92 +604,90 @@ func (c *tsoClient) getAllTSOStreamBuilders() map[string]tsoStreamBuilder { return streamBuilders } -type tsoInfo struct { - tsoServer string - reqKeyspaceGroupID uint32 - respKeyspaceGroupID uint32 - respReceivedAt time.Time - physical int64 - logical int64 +func (c *tsoClient) createTSODispatcher(dcLocation string) { + dispatcher := newTSODispatcher(c.ctx, dcLocation, defaultMaxTSOBatchSize, c) + if _, ok := c.tsoDispatcher.LoadOrStore(dcLocation, dispatcher); !ok { + // Create a new dispatcher for the dc-location to handle the TSO requests. + c.wg.Add(1) + go dispatcher.handleDispatcher(&c.wg) + } else { + dispatcher.close() + } } -func (c *tsoClient) processRequests( - stream tsoStream, dcLocation string, tbc *tsoBatchController, -) error { - requests := tbc.getCollectedRequests() - // nolint - for _, req := range requests { - defer trace.StartRegion(req.requestCtx, "pdclient.tsoReqSend").End() - if span := opentracing.SpanFromContext(req.requestCtx); span != nil && span.Tracer() != nil { - span = span.Tracer().StartSpan("pdclient.processRequests", opentracing.ChildOf(span.Context())) - defer span.Finish() +func (c *tsoClient) closeTSODispatcher() { + c.tsoDispatcher.Range(func(_, dispatcherInterface any) bool { + if dispatcherInterface != nil { + dispatcherInterface.(*tsoDispatcher).close() } - } + return true + }) +} - count := int64(len(requests)) - reqKeyspaceGroupID := c.svcDiscovery.GetKeyspaceGroupID() - respKeyspaceGroupID, physical, logical, suffixBits, err := stream.processRequests( - c.svcDiscovery.GetClusterID(), c.svcDiscovery.GetKeyspaceID(), reqKeyspaceGroupID, - dcLocation, count, tbc.batchStartTime) - if err != nil { - tbc.finishCollectedRequests(0, 0, 0, err) - return err - } - // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request. - firstLogical := tsoutil.AddLogical(logical, -count+1, suffixBits) - curTSOInfo := &tsoInfo{ - tsoServer: stream.getServerURL(), - reqKeyspaceGroupID: reqKeyspaceGroupID, - respKeyspaceGroupID: respKeyspaceGroupID, - respReceivedAt: time.Now(), - physical: physical, - logical: tsoutil.AddLogical(firstLogical, count-1, suffixBits), - } - c.compareAndSwapTS(dcLocation, curTSOInfo, physical, firstLogical) - tbc.finishCollectedRequests(physical, firstLogical, suffixBits, nil) - return nil +func (c *tsoClient) updateTSODispatcher() { + // Set up the new TSO dispatcher and batch controller. 
+ c.GetTSOAllocators().Range(func(dcLocationKey, _ any) bool { + dcLocation := dcLocationKey.(string) + if _, ok := c.getTSODispatcher(dcLocation); !ok { + c.createTSODispatcher(dcLocation) + } + return true + }) + // Clean up the unused TSO dispatcher + c.tsoDispatcher.Range(func(dcLocationKey, dispatcher any) bool { + dcLocation := dcLocationKey.(string) + // Skip the Global TSO Allocator + if dcLocation == globalDCLocation { + return true + } + if _, exist := c.GetTSOAllocators().Load(dcLocation); !exist { + log.Info("[tso] delete unused tso dispatcher", zap.String("dc-location", dcLocation)) + c.tsoDispatcher.Delete(dcLocation) + dispatcher.(*tsoDispatcher).close() + } + return true + }) } -func (c *tsoClient) compareAndSwapTS( - dcLocation string, - curTSOInfo *tsoInfo, - physical, firstLogical int64, -) { - val, loaded := c.lastTSOInfoMap.LoadOrStore(dcLocation, curTSOInfo) - if !loaded { - return - } - lastTSOInfo := val.(*tsoInfo) - if lastTSOInfo.respKeyspaceGroupID != curTSOInfo.respKeyspaceGroupID { - log.Info("[tso] keyspace group changed", - zap.String("dc-location", dcLocation), - zap.Uint32("old-group-id", lastTSOInfo.respKeyspaceGroupID), - zap.Uint32("new-group-id", curTSOInfo.respKeyspaceGroupID)) +// dispatchRequest will send the TSO request to the corresponding TSO dispatcher. +func (c *tsoClient) dispatchRequest(request *tsoRequest) (bool, error) { + dispatcher, ok := c.getTSODispatcher(request.dcLocation) + if !ok { + err := errs.ErrClientGetTSO.FastGenByArgs(fmt.Sprintf("unknown dc-location %s to the client", request.dcLocation)) + log.Error("[tso] dispatch tso request error", zap.String("dc-location", request.dcLocation), errs.ZapError(err)) + c.svcDiscovery.ScheduleCheckMemberChanged() + // New dispatcher could be created in the meantime, which is retryable. + return true, err } - // The TSO we get is a range like [largestLogical-count+1, largestLogical], so we save the last TSO's largest logical - // to compare with the new TSO's first logical. For example, if we have a TSO resp with logical 10, count 5, then - // all TSOs we get will be [6, 7, 8, 9, 10]. lastTSOInfo.logical stores the logical part of the largest ts returned - // last time. - if tsoutil.TSLessEqual(physical, firstLogical, lastTSOInfo.physical, lastTSOInfo.logical) { - log.Panic("[tso] timestamp fallback", - zap.String("dc-location", dcLocation), - zap.Uint32("keyspace", c.svcDiscovery.GetKeyspaceID()), - zap.String("last-ts", fmt.Sprintf("(%d, %d)", lastTSOInfo.physical, lastTSOInfo.logical)), - zap.String("cur-ts", fmt.Sprintf("(%d, %d)", physical, firstLogical)), - zap.String("last-tso-server", lastTSOInfo.tsoServer), - zap.String("cur-tso-server", curTSOInfo.tsoServer), - zap.Uint32("last-keyspace-group-in-request", lastTSOInfo.reqKeyspaceGroupID), - zap.Uint32("cur-keyspace-group-in-request", curTSOInfo.reqKeyspaceGroupID), - zap.Uint32("last-keyspace-group-in-response", lastTSOInfo.respKeyspaceGroupID), - zap.Uint32("cur-keyspace-group-in-response", curTSOInfo.respKeyspaceGroupID), - zap.Time("last-response-received-at", lastTSOInfo.respReceivedAt), - zap.Time("cur-response-received-at", curTSOInfo.respReceivedAt)) + defer trace.StartRegion(request.requestCtx, "pdclient.tsoReqEnqueue").End() + select { + case <-request.requestCtx.Done(): + // Caller cancelled the request, no need to retry. + return false, request.requestCtx.Err() + case <-request.clientCtx.Done(): + // Client is closed, no need to retry. 
+ return false, request.clientCtx.Err() + case <-c.ctx.Done(): + // tsoClient is closed due to the PD service mode switch, which is retryable. + return true, c.ctx.Err() + default: + // This failpoint will increase the possibility that the request is sent to a closed dispatcher. + failpoint.Inject("delayDispatchTSORequest", func() { + time.Sleep(time.Second) + }) + dispatcher.push(request) + } + // Check the contexts again to make sure the request is not been sent to a closed dispatcher. + // Never retry on these conditions to prevent unexpected data race. + select { + case <-request.requestCtx.Done(): + return false, request.requestCtx.Err() + case <-request.clientCtx.Done(): + return false, request.clientCtx.Err() + case <-c.ctx.Done(): + return false, c.ctx.Err() + default: } - lastTSOInfo.tsoServer = curTSOInfo.tsoServer - lastTSOInfo.reqKeyspaceGroupID = curTSOInfo.reqKeyspaceGroupID - lastTSOInfo.respKeyspaceGroupID = curTSOInfo.respKeyspaceGroupID - lastTSOInfo.respReceivedAt = curTSOInfo.respReceivedAt - lastTSOInfo.physical = curTSOInfo.physical - lastTSOInfo.logical = curTSOInfo.logical + return false, nil } diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index 7528293a7338..d5b52ad60390 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -22,106 +22,19 @@ import ( "sync" "time" + "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/log" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/retry" "github.com/tikv/pd/client/timerpool" + "github.com/tikv/pd/client/tsoutil" "go.uber.org/zap" - healthpb "google.golang.org/grpc/health/grpc_health_v1" ) -const ( - tsLoopDCCheckInterval = time.Minute - defaultMaxTSOBatchSize = 10000 // should be higher if client is sending requests in burst - retryInterval = 500 * time.Millisecond - maxRetryTimes = 6 -) - -type tsoDispatcher struct { - dispatcherCancel context.CancelFunc - tsoBatchController *tsoBatchController -} - -func (c *tsoClient) dispatchRequest(request *tsoRequest) (bool, error) { - dispatcher, ok := c.tsoDispatcher.Load(request.dcLocation) - if !ok { - err := errs.ErrClientGetTSO.FastGenByArgs(fmt.Sprintf("unknown dc-location %s to the client", request.dcLocation)) - log.Error("[tso] dispatch tso request error", zap.String("dc-location", request.dcLocation), errs.ZapError(err)) - c.svcDiscovery.ScheduleCheckMemberChanged() - // New dispatcher could be created in the meantime, which is retryable. - return true, err - } - - defer trace.StartRegion(request.requestCtx, "pdclient.tsoReqEnqueue").End() - select { - case <-request.requestCtx.Done(): - // Caller cancelled the request, no need to retry. - return false, request.requestCtx.Err() - case <-request.clientCtx.Done(): - // Client is closed, no need to retry. - return false, request.clientCtx.Err() - case <-c.ctx.Done(): - // tsoClient is closed due to the PD service mode switch, which is retryable. - return true, c.ctx.Err() - default: - // This failpoint will increase the possibility that the request is sent to a closed dispatcher. - failpoint.Inject("delayDispatchTSORequest", func() { - time.Sleep(time.Second) - }) - dispatcher.(*tsoDispatcher).tsoBatchController.tsoRequestCh <- request - } - // Check the contexts again to make sure the request is not been sent to a closed dispatcher. - // Never retry on these conditions to prevent unexpected data race. 
- select { - case <-request.requestCtx.Done(): - return false, request.requestCtx.Err() - case <-request.clientCtx.Done(): - return false, request.clientCtx.Err() - case <-c.ctx.Done(): - return false, c.ctx.Err() - default: - } - return false, nil -} - -func (c *tsoClient) closeTSODispatcher() { - c.tsoDispatcher.Range(func(_, dispatcherInterface any) bool { - if dispatcherInterface != nil { - dispatcher := dispatcherInterface.(*tsoDispatcher) - dispatcher.dispatcherCancel() - dispatcher.tsoBatchController.clear() - } - return true - }) -} - -func (c *tsoClient) updateTSODispatcher() { - // Set up the new TSO dispatcher and batch controller. - c.GetTSOAllocators().Range(func(dcLocationKey, _ any) bool { - dcLocation := dcLocationKey.(string) - if !c.checkTSODispatcher(dcLocation) { - c.createTSODispatcher(dcLocation) - } - return true - }) - // Clean up the unused TSO dispatcher - c.tsoDispatcher.Range(func(dcLocationKey, dispatcher any) bool { - dcLocation := dcLocationKey.(string) - // Skip the Global TSO Allocator - if dcLocation == globalDCLocation { - return true - } - if _, exist := c.GetTSOAllocators().Load(dcLocation); !exist { - log.Info("[tso] delete unused tso dispatcher", zap.String("dc-location", dcLocation)) - dispatcher.(*tsoDispatcher).dispatcherCancel() - c.tsoDispatcher.Delete(dcLocation) - } - return true - }) -} - +// deadline is used to control the TS request timeout manually, +// it will be sent to the `tsDeadlineCh` to be handled by the `watchTSDeadline` goroutine. type deadline struct { timer *time.Timer done chan struct{} @@ -141,184 +54,119 @@ func newTSDeadline( } } -func (c *tsoClient) tsCancelLoop() { - defer c.wg.Done() - - tsCancelLoopCtx, tsCancelLoopCancel := context.WithCancel(c.ctx) - defer tsCancelLoopCancel() - - ticker := time.NewTicker(tsLoopDCCheckInterval) - defer ticker.Stop() - for { - // Watch every dc-location's tsDeadlineCh - c.GetTSOAllocators().Range(func(dcLocation, _ any) bool { - c.watchTSDeadline(tsCancelLoopCtx, dcLocation.(string)) - return true - }) - select { - case <-c.checkTSDeadlineCh: - continue - case <-ticker.C: - continue - case <-tsCancelLoopCtx.Done(): - log.Info("exit tso requests cancel loop") - return - } - } +type tsoInfo struct { + tsoServer string + reqKeyspaceGroupID uint32 + respKeyspaceGroupID uint32 + respReceivedAt time.Time + physical int64 + logical int64 } -func (c *tsoClient) watchTSDeadline(ctx context.Context, dcLocation string) { - if _, exist := c.tsDeadline.Load(dcLocation); !exist { - tsDeadlineCh := make(chan *deadline, 1) - c.tsDeadline.Store(dcLocation, tsDeadlineCh) - go func(dc string, tsDeadlineCh <-chan *deadline) { - for { - select { - case d := <-tsDeadlineCh: - select { - case <-d.timer.C: - log.Error("[tso] tso request is canceled due to timeout", zap.String("dc-location", dc), errs.ZapError(errs.ErrClientGetTSOTimeout)) - d.cancel() - timerpool.GlobalTimerPool.Put(d.timer) - case <-d.done: - timerpool.GlobalTimerPool.Put(d.timer) - case <-ctx.Done(): - timerpool.GlobalTimerPool.Put(d.timer) - return - } - case <-ctx.Done(): - return - } - } - }(dcLocation, tsDeadlineCh) - } +type tsoServiceProvider interface { + getOption() *option + getServiceDiscovery() ServiceDiscovery + updateConnectionCtxs(ctx context.Context, dc string, connectionCtxs *sync.Map) bool } -func (c *tsoClient) tsoDispatcherCheckLoop() { - defer c.wg.Done() +type tsoDispatcher struct { + ctx context.Context + cancel context.CancelFunc + dc string - loopCtx, loopCancel := context.WithCancel(c.ctx) - defer loopCancel() + 
provider tsoServiceProvider + // URL -> *connectionContext + connectionCtxs *sync.Map + batchController *tsoBatchController + tsDeadlineCh chan *deadline + lastTSOInfo *tsoInfo - ticker := time.NewTicker(tsLoopDCCheckInterval) - defer ticker.Stop() - for { - c.updateTSODispatcher() - select { - case <-ticker.C: - case <-c.checkTSODispatcherCh: - case <-loopCtx.Done(): - log.Info("exit tso dispatcher loop") - return - } + updateConnectionCtxsCh chan struct{} +} + +func newTSODispatcher( + ctx context.Context, + dc string, + maxBatchSize int, + provider tsoServiceProvider, +) *tsoDispatcher { + dispatcherCtx, dispatcherCancel := context.WithCancel(ctx) + tsoBatchController := newTSOBatchController( + make(chan *tsoRequest, maxBatchSize*2), + maxBatchSize, + ) + failpoint.Inject("shortDispatcherChannel", func() { + tsoBatchController = newTSOBatchController( + make(chan *tsoRequest, 1), + maxBatchSize, + ) + }) + td := &tsoDispatcher{ + ctx: dispatcherCtx, + cancel: dispatcherCancel, + dc: dc, + provider: provider, + connectionCtxs: &sync.Map{}, + batchController: tsoBatchController, + tsDeadlineCh: make(chan *deadline, 1), + updateConnectionCtxsCh: make(chan struct{}, 1), } + go td.watchTSDeadline() + return td } -func (c *tsoClient) checkAllocator( - dispatcherCtx context.Context, - forwardCancel context.CancelFunc, - dc, forwardedHostTrim, addr, url string, - updateAndClear func(newAddr string, connectionCtx *tsoConnectionContext)) { - defer func() { - // cancel the forward stream - forwardCancel() - requestForwarded.WithLabelValues(forwardedHostTrim, addr).Set(0) - }() - cc, u := c.GetTSOAllocatorClientConnByDCLocation(dc) - var healthCli healthpb.HealthClient - ticker := time.NewTicker(time.Second) - defer ticker.Stop() +func (td *tsoDispatcher) watchTSDeadline() { + log.Info("[tso] start tso deadline watcher", zap.String("dc-location", td.dc)) + defer log.Info("[tso] exit tso deadline watcher", zap.String("dc-location", td.dc)) for { - // the pd/allocator leader change, we need to re-establish the stream - if u != url { - log.Info("[tso] the leader of the allocator leader is changed", zap.String("dc", dc), zap.String("origin", url), zap.String("new", u)) - return - } - if healthCli == nil && cc != nil { - healthCli = healthpb.NewHealthClient(cc) - } - if healthCli != nil { - healthCtx, healthCancel := context.WithTimeout(dispatcherCtx, c.option.timeout) - resp, err := healthCli.Check(healthCtx, &healthpb.HealthCheckRequest{Service: ""}) - failpoint.Inject("unreachableNetwork", func() { - resp.Status = healthpb.HealthCheckResponse_UNKNOWN - }) - healthCancel() - if err == nil && resp.GetStatus() == healthpb.HealthCheckResponse_SERVING { - // create a stream of the original allocator - cctx, cancel := context.WithCancel(dispatcherCtx) - stream, err := c.tsoStreamBuilderFactory.makeBuilder(cc).build(cctx, cancel, c.option.timeout) - if err == nil && stream != nil { - log.Info("[tso] recover the original tso stream since the network has become normal", zap.String("dc", dc), zap.String("url", url)) - updateAndClear(url, &tsoConnectionContext{url, stream, cctx, cancel}) - return - } - } - } select { - case <-dispatcherCtx.Done(): + case d := <-td.tsDeadlineCh: + select { + case <-d.timer.C: + log.Error("[tso] tso request is canceled due to timeout", + zap.String("dc-location", td.dc), errs.ZapError(errs.ErrClientGetTSOTimeout)) + d.cancel() + timerpool.GlobalTimerPool.Put(d.timer) + case <-d.done: + timerpool.GlobalTimerPool.Put(d.timer) + case <-td.ctx.Done(): + 
timerpool.GlobalTimerPool.Put(d.timer) + return + } + case <-td.ctx.Done(): return - case <-ticker.C: - // To ensure we can get the latest allocator leader - // and once the leader is changed, we can exit this function. - cc, u = c.GetTSOAllocatorClientConnByDCLocation(dc) } } } -func (c *tsoClient) checkTSODispatcher(dcLocation string) bool { - dispatcher, ok := c.tsoDispatcher.Load(dcLocation) - if !ok || dispatcher == nil { - return false +func (td *tsoDispatcher) scheduleUpdateConnectionCtxs() { + select { + case td.updateConnectionCtxsCh <- struct{}{}: + default: } - return true } -func (c *tsoClient) createTSODispatcher(dcLocation string) { - dispatcherCtx, dispatcherCancel := context.WithCancel(c.ctx) - dispatcher := &tsoDispatcher{ - dispatcherCancel: dispatcherCancel, - tsoBatchController: newTSOBatchController( - make(chan *tsoRequest, defaultMaxTSOBatchSize*2), - defaultMaxTSOBatchSize), - } - failpoint.Inject("shortDispatcherChannel", func() { - dispatcher = &tsoDispatcher{ - dispatcherCancel: dispatcherCancel, - tsoBatchController: newTSOBatchController( - make(chan *tsoRequest, 1), - defaultMaxTSOBatchSize), - } - }) +func (td *tsoDispatcher) close() { + td.cancel() + td.batchController.clear() +} - if _, ok := c.tsoDispatcher.LoadOrStore(dcLocation, dispatcher); !ok { - // Successfully stored the value. Start the following goroutine. - // Each goroutine is responsible for handling the tso stream request for its dc-location. - // The only case that will make the dispatcher goroutine exit - // is that the loopCtx is done, otherwise there is no circumstance - // this goroutine should exit. - c.wg.Add(1) - go c.handleDispatcher(dispatcherCtx, dcLocation, dispatcher.tsoBatchController) - log.Info("[tso] tso dispatcher created", zap.String("dc-location", dcLocation)) - } else { - dispatcherCancel() - } +func (td *tsoDispatcher) push(request *tsoRequest) { + td.batchController.tsoRequestCh <- request } -func (c *tsoClient) handleDispatcher( - dispatcherCtx context.Context, - dc string, - tbc *tsoBatchController, -) { +func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { var ( - err error - streamURL string - stream tsoStream - streamCtx context.Context - cancel context.CancelFunc - // url -> connectionContext - connectionCtxs sync.Map + ctx = td.ctx + dc = td.dc + provider = td.provider + svcDiscovery = provider.getServiceDiscovery() + option = provider.getOption() + connectionCtxs = td.connectionCtxs + batchController = td.batchController ) + log.Info("[tso] tso dispatcher created", zap.String("dc-location", dc)) + // Clean up the connectionCtxs when the dispatcher exits. defer func() { log.Info("[tso] exit tso dispatcher", zap.String("dc-location", dc)) // Cancel all connections. @@ -327,73 +175,37 @@ func (c *tsoClient) handleDispatcher( return true }) // Clear the tso batch controller. - tbc.clear() - c.wg.Done() + batchController.clear() + wg.Done() }() - // Call updateTSOConnectionCtxs once to init the connectionCtxs first. - c.updateTSOConnectionCtxs(dispatcherCtx, dc, &connectionCtxs) - // Only the Global TSO needs to watch the updateTSOConnectionCtxsCh to sense the - // change of the cluster when TSO Follower Proxy is enabled. - // TODO: support TSO Follower Proxy for the Local TSO. 
- if dc == globalDCLocation { - go func() { - var updateTicker = &time.Ticker{} - setNewUpdateTicker := func(ticker *time.Ticker) { - if updateTicker.C != nil { - updateTicker.Stop() - } - updateTicker = ticker - } - // Set to nil before returning to ensure that the existing ticker can be GC. - defer setNewUpdateTicker(nil) - - for { - select { - case <-dispatcherCtx.Done(): - return - case <-c.option.enableTSOFollowerProxyCh: - enableTSOFollowerProxy := c.option.getEnableTSOFollowerProxy() - log.Info("[tso] tso follower proxy status changed", - zap.String("dc-location", dc), - zap.Bool("enable", enableTSOFollowerProxy)) - if enableTSOFollowerProxy && updateTicker.C == nil { - // Because the TSO Follower Proxy is enabled, - // the periodic check needs to be performed. - setNewUpdateTicker(time.NewTicker(memberUpdateInterval)) - } else if !enableTSOFollowerProxy && updateTicker.C != nil { - // Because the TSO Follower Proxy is disabled, - // the periodic check needs to be turned off. - setNewUpdateTicker(&time.Ticker{}) - } else { - // The status of TSO Follower Proxy does not change, and updateTSOConnectionCtxs is not triggered - continue - } - case <-updateTicker.C: - case <-c.updateTSOConnectionCtxsCh: - } - c.updateTSOConnectionCtxs(dispatcherCtx, dc, &connectionCtxs) - } - }() - } + // Daemon goroutine to update the connectionCtxs periodically and handle the `connectionCtxs` update event. + go td.connectionCtxsUpdater() + var ( + err error + streamCtx context.Context + cancel context.CancelFunc + streamURL string + stream tsoStream + ) // Loop through each batch of TSO requests and send them for processing. - streamLoopTimer := time.NewTimer(c.option.timeout) + streamLoopTimer := time.NewTimer(option.timeout) defer streamLoopTimer.Stop() bo := retry.InitialBackoffer(updateMemberBackOffBaseTime, updateMemberTimeout, updateMemberBackOffBaseTime) tsoBatchLoop: for { select { - case <-dispatcherCtx.Done(): + case <-ctx.Done(): return default: } // Start to collect the TSO requests. - maxBatchWaitInterval := c.option.getMaxTSOBatchWaitInterval() + maxBatchWaitInterval := option.getMaxTSOBatchWaitInterval() // Once the TSO requests are collected, must make sure they could be finished or revoked eventually, // otherwise the upper caller may get blocked on waiting for the results. - if err = tbc.fetchPendingRequests(dispatcherCtx, maxBatchWaitInterval); err != nil { + if err = batchController.fetchPendingRequests(ctx, maxBatchWaitInterval); err != nil { // Finish the collected requests if the fetch failed. - tbc.finishCollectedRequests(0, 0, 0, errors.WithStack(err)) + batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(err)) if err == context.Canceled { log.Info("[tso] stop fetching the pending tso requests due to context canceled", zap.String("dc-location", dc)) @@ -405,7 +217,7 @@ tsoBatchLoop: return } if maxBatchWaitInterval >= 0 { - tbc.adjustBestBatchSize() + batchController.adjustBestBatchSize() } // Stop the timer if it's not stopped. if !streamLoopTimer.Stop() { @@ -416,33 +228,33 @@ tsoBatchLoop: } // We need be careful here, see more details in the comments of Timer.Reset. // https://pkg.go.dev/time@master#Timer.Reset - streamLoopTimer.Reset(c.option.timeout) + streamLoopTimer.Reset(option.timeout) // Choose a stream to send the TSO gRPC request. 
streamChoosingLoop: for { - connectionCtx := chooseStream(&connectionCtxs) + connectionCtx := chooseStream(connectionCtxs) if connectionCtx != nil { - streamURL, stream, streamCtx, cancel = connectionCtx.streamURL, connectionCtx.stream, connectionCtx.ctx, connectionCtx.cancel + streamCtx, cancel, streamURL, stream = connectionCtx.ctx, connectionCtx.cancel, connectionCtx.streamURL, connectionCtx.stream } // Check stream and retry if necessary. if stream == nil { log.Info("[tso] tso stream is not ready", zap.String("dc", dc)) - if c.updateTSOConnectionCtxs(dispatcherCtx, dc, &connectionCtxs) { + if provider.updateConnectionCtxs(ctx, dc, connectionCtxs) { continue streamChoosingLoop } timer := time.NewTimer(retryInterval) select { - case <-dispatcherCtx.Done(): + case <-ctx.Done(): // Finish the collected requests if the context is canceled. - tbc.finishCollectedRequests(0, 0, 0, errors.WithStack(dispatcherCtx.Err())) + batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(ctx.Err())) timer.Stop() return case <-streamLoopTimer.C: err = errs.ErrClientCreateTSOStream.FastGenByArgs(errs.RetryTimeoutErr) log.Error("[tso] create tso stream error", zap.String("dc-location", dc), errs.ZapError(err)) - c.svcDiscovery.ScheduleCheckMemberChanged() + svcDiscovery.ScheduleCheckMemberChanged() // Finish the collected requests if the stream is failed to be created. - tbc.finishCollectedRequests(0, 0, 0, errors.WithStack(err)) + batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(err)) timer.Stop() continue tsoBatchLoop case <-timer.C: @@ -463,31 +275,25 @@ tsoBatchLoop: } } done := make(chan struct{}) - dl := newTSDeadline(c.option.timeout, done, cancel) - tsDeadlineCh, ok := c.tsDeadline.Load(dc) - for !ok || tsDeadlineCh == nil { - c.scheduleCheckTSDeadline() - time.Sleep(time.Millisecond * 100) - tsDeadlineCh, ok = c.tsDeadline.Load(dc) - } + dl := newTSDeadline(option.timeout, done, cancel) select { - case <-dispatcherCtx.Done(): + case <-ctx.Done(): // Finish the collected requests if the context is canceled. - tbc.finishCollectedRequests(0, 0, 0, errors.WithStack(dispatcherCtx.Err())) + batchController.finishCollectedRequests(0, 0, 0, errors.WithStack(ctx.Err())) return - case tsDeadlineCh.(chan *deadline) <- dl: + case td.tsDeadlineCh <- dl: } // processRequests guarantees that the collected requests could be finished properly. - err = c.processRequests(stream, dc, tbc) + err = td.processRequests(stream, dc, td.batchController) close(done) // If error happens during tso stream handling, reset stream and run the next trial. if err != nil { select { - case <-dispatcherCtx.Done(): + case <-ctx.Done(): return default: } - c.svcDiscovery.ScheduleCheckMemberChanged() + svcDiscovery.ScheduleCheckMemberChanged() log.Error("[tso] getTS error after processing requests", zap.String("dc-location", dc), zap.String("stream-url", streamURL), @@ -498,24 +304,79 @@ tsoBatchLoop: stream = nil // Because ScheduleCheckMemberChanged is asynchronous, if the leader changes, we better call `updateMember` ASAP. 
if IsLeaderChange(err) { - if err := bo.Exec(dispatcherCtx, c.svcDiscovery.CheckMemberChanged); err != nil { + if err := bo.Exec(ctx, svcDiscovery.CheckMemberChanged); err != nil { select { - case <-dispatcherCtx.Done(): + case <-ctx.Done(): return default: } } // Because the TSO Follower Proxy could be configured online, - // If we change it from on -> off, background updateTSOConnectionCtxs + // If we change it from on -> off, background updateConnectionCtxs // will cancel the current stream, then the EOF error caused by cancel() - // should not trigger the updateTSOConnectionCtxs here. + // should not trigger the updateConnectionCtxs here. // So we should only call it when the leader changes. - c.updateTSOConnectionCtxs(dispatcherCtx, dc, &connectionCtxs) + provider.updateConnectionCtxs(ctx, dc, connectionCtxs) } } } } +// updateConnectionCtxs updates the `connectionCtxs` for the specified DC location regularly. +func (td *tsoDispatcher) connectionCtxsUpdater() { + var ( + ctx = td.ctx + dc = td.dc + connectionCtxs = td.connectionCtxs + provider = td.provider + option = td.provider.getOption() + updateTicker = &time.Ticker{} + ) + + log.Info("[tso] start tso connection contexts updater", zap.String("dc-location", dc)) + setNewUpdateTicker := func(ticker *time.Ticker) { + if updateTicker.C != nil { + updateTicker.Stop() + } + updateTicker = ticker + } + // Set to nil before returning to ensure that the existing ticker can be GC. + defer setNewUpdateTicker(nil) + + for { + provider.updateConnectionCtxs(ctx, dc, connectionCtxs) + select { + case <-ctx.Done(): + log.Info("[tso] exit tso connection contexts updater", zap.String("dc-location", dc)) + return + case <-option.enableTSOFollowerProxyCh: + // TODO: implement support of TSO Follower Proxy for the Local TSO. + if dc != globalDCLocation { + continue + } + enableTSOFollowerProxy := option.getEnableTSOFollowerProxy() + log.Info("[tso] tso follower proxy status changed", + zap.String("dc-location", dc), + zap.Bool("enable", enableTSOFollowerProxy)) + if enableTSOFollowerProxy && updateTicker.C == nil { + // Because the TSO Follower Proxy is enabled, + // the periodic check needs to be performed. + setNewUpdateTicker(time.NewTicker(memberUpdateInterval)) + } else if !enableTSOFollowerProxy && updateTicker.C != nil { + // Because the TSO Follower Proxy is disabled, + // the periodic check needs to be turned off. + setNewUpdateTicker(&time.Ticker{}) + } else { + continue + } + case <-updateTicker.C: + // Triggered periodically when the TSO Follower Proxy is enabled. + case <-td.updateConnectionCtxsCh: + // Triggered by the leader/follower change. + } + } +} + // chooseStream uses the reservoir sampling algorithm to randomly choose a connection. // connectionCtxs will only have only one stream to choose when the TSO Follower Proxy is off. 
func chooseStream(connectionCtxs *sync.Map) (connectionCtx *tsoConnectionContext) { @@ -530,3 +391,94 @@ func chooseStream(connectionCtxs *sync.Map) (connectionCtx *tsoConnectionContext }) return connectionCtx } + +func (td *tsoDispatcher) processRequests( + stream tsoStream, dcLocation string, tbc *tsoBatchController, +) error { + var ( + requests = tbc.getCollectedRequests() + traceRegions = make([]*trace.Region, 0, len(requests)) + spans = make([]opentracing.Span, 0, len(requests)) + ) + for _, req := range requests { + traceRegions = append(traceRegions, trace.StartRegion(req.requestCtx, "pdclient.tsoReqSend")) + if span := opentracing.SpanFromContext(req.requestCtx); span != nil && span.Tracer() != nil { + spans = append(spans, span.Tracer().StartSpan("pdclient.processRequests", opentracing.ChildOf(span.Context()))) + } + } + defer func() { + for i := range spans { + spans[i].Finish() + } + for i := range traceRegions { + traceRegions[i].End() + } + }() + + var ( + count = int64(len(requests)) + svcDiscovery = td.provider.getServiceDiscovery() + clusterID = svcDiscovery.GetClusterID() + keyspaceID = svcDiscovery.GetKeyspaceID() + reqKeyspaceGroupID = svcDiscovery.GetKeyspaceGroupID() + ) + respKeyspaceGroupID, physical, logical, suffixBits, err := stream.processRequests( + clusterID, keyspaceID, reqKeyspaceGroupID, + dcLocation, count, tbc.batchStartTime) + if err != nil { + tbc.finishCollectedRequests(0, 0, 0, err) + return err + } + curTSOInfo := &tsoInfo{ + tsoServer: stream.getServerURL(), + reqKeyspaceGroupID: reqKeyspaceGroupID, + respKeyspaceGroupID: respKeyspaceGroupID, + respReceivedAt: time.Now(), + physical: physical, + logical: logical, + } + // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request. + firstLogical := tsoutil.AddLogical(logical, -count+1, suffixBits) + td.compareAndSwapTS(curTSOInfo, firstLogical) + tbc.finishCollectedRequests(physical, firstLogical, suffixBits, nil) + return nil +} + +func (td *tsoDispatcher) compareAndSwapTS( + curTSOInfo *tsoInfo, firstLogical int64, +) { + if td.lastTSOInfo != nil { + var ( + lastTSOInfo = td.lastTSOInfo + dc = td.dc + physical = curTSOInfo.physical + keyspaceID = td.provider.getServiceDiscovery().GetKeyspaceID() + ) + if td.lastTSOInfo.respKeyspaceGroupID != curTSOInfo.respKeyspaceGroupID { + log.Info("[tso] keyspace group changed", + zap.String("dc-location", dc), + zap.Uint32("old-group-id", lastTSOInfo.respKeyspaceGroupID), + zap.Uint32("new-group-id", curTSOInfo.respKeyspaceGroupID)) + } + // The TSO we get is a range like [largestLogical-count+1, largestLogical], so we save the last TSO's largest logical + // to compare with the new TSO's first logical. For example, if we have a TSO resp with logical 10, count 5, then + // all TSOs we get will be [6, 7, 8, 9, 10]. lastTSOInfo.logical stores the logical part of the largest ts returned + // last time. 
+ if tsoutil.TSLessEqual(physical, firstLogical, lastTSOInfo.physical, lastTSOInfo.logical) { + log.Panic("[tso] timestamp fallback", + zap.String("dc-location", dc), + zap.Uint32("keyspace", keyspaceID), + zap.String("last-ts", fmt.Sprintf("(%d, %d)", lastTSOInfo.physical, lastTSOInfo.logical)), + zap.String("cur-ts", fmt.Sprintf("(%d, %d)", physical, firstLogical)), + zap.String("last-tso-server", lastTSOInfo.tsoServer), + zap.String("cur-tso-server", curTSOInfo.tsoServer), + zap.Uint32("last-keyspace-group-in-request", lastTSOInfo.reqKeyspaceGroupID), + zap.Uint32("cur-keyspace-group-in-request", curTSOInfo.reqKeyspaceGroupID), + zap.Uint32("last-keyspace-group-in-response", lastTSOInfo.respKeyspaceGroupID), + zap.Uint32("cur-keyspace-group-in-response", curTSOInfo.respKeyspaceGroupID), + zap.Time("last-response-received-at", lastTSOInfo.respReceivedAt), + zap.Time("cur-response-received-at", curTSOInfo.respReceivedAt)) + } + } + td.lastTSOInfo = curTSOInfo +} diff --git a/client/tso_request.go b/client/tso_request.go index f30ceb5268a6..b912fa354973 100644 --- a/client/tso_request.go +++ b/client/tso_request.go @@ -63,8 +63,8 @@ func (req *tsoRequest) Wait() (physical int64, logical int64, err error) { cmdDurationTSOAsyncWait.Observe(start.Sub(req.start).Seconds()) select { case err = <-req.done: - defer trace.StartRegion(req.requestCtx, "pdclient.tsoReqDone").End() defer req.pool.Put(req) + defer trace.StartRegion(req.requestCtx, "pdclient.tsoReqDone").End() err = errors.WithStack(err) if err != nil { cmdFailDurationTSO.Observe(time.Since(req.start).Seconds()) diff --git a/pkg/core/region.go b/pkg/core/region.go index efbe80194d17..be8f392f05e2 100644 --- a/pkg/core/region.go +++ b/pkg/core/region.go @@ -81,6 +81,8 @@ type RegionInfo struct { buckets unsafe.Pointer // source is used to indicate region's source, such as Storage/Sync/Heartbeat. source RegionSource + // ref is used to indicate the reference count of the region in root-tree and sub-tree. + ref atomic.Int32 } // RegionSource is the source of region. @@ -106,6 +108,21 @@ func (r *RegionInfo) LoadedFromSync() bool { return r.source == Sync } +// IncRef increases the reference count. +func (r *RegionInfo) IncRef() { + r.ref.Add(1) +} + +// DecRef decreases the reference count. +func (r *RegionInfo) DecRef() { + r.ref.Add(-1) +} + +// GetRef returns the reference count. +func (r *RegionInfo) GetRef() int32 { + return r.ref.Load() +} + // NewRegionInfo creates RegionInfo with region's meta and leader peer. func NewRegionInfo(region *metapb.Region, leader *metapb.Peer, opts ...RegionCreateOption) *RegionInfo { regionInfo := &RegionInfo{ @@ -903,7 +920,7 @@ type RegionsInfo struct { // NewRegionsInfo creates RegionsInfo with tree, regions, leaders and followers func NewRegionsInfo() *RegionsInfo { return &RegionsInfo{ - tree: newRegionTree(), + tree: newRegionTreeWithCountRef(), regions: make(map[uint64]*regionItem), subRegions: make(map[uint64]*regionItem), leaders: make(map[uint64]*regionTree), @@ -1092,10 +1109,14 @@ func (r *RegionsInfo) UpdateSubTreeOrderInsensitive(region *RegionInfo) { r.subRegions[region.GetID()] = item // It has been removed and all information needs to be updated again. // Set peers then. 
- setPeer := func(peersMap map[uint64]*regionTree, storeID uint64, item *regionItem) { + setPeer := func(peersMap map[uint64]*regionTree, storeID uint64, item *regionItem, countRef bool) { store, ok := peersMap[storeID] if !ok { - store = newRegionTree() + if !countRef { + store = newRegionTree() + } else { + store = newRegionTreeWithCountRef() + } peersMap[storeID] = store } store.update(item, false) @@ -1106,17 +1127,17 @@ func (r *RegionsInfo) UpdateSubTreeOrderInsensitive(region *RegionInfo) { storeID := peer.GetStoreId() if peer.GetId() == region.leader.GetId() { // Add leader peer to leaders. - setPeer(r.leaders, storeID, item) + setPeer(r.leaders, storeID, item, true) } else { // Add follower peer to followers. - setPeer(r.followers, storeID, item) + setPeer(r.followers, storeID, item, false) } } setPeers := func(peersMap map[uint64]*regionTree, peers []*metapb.Peer) { for _, peer := range peers { storeID := peer.GetStoreId() - setPeer(peersMap, storeID, item) + setPeer(peersMap, storeID, item, false) } } // Add to learners. @@ -1284,10 +1305,14 @@ func (r *RegionsInfo) UpdateSubTree(region, origin *RegionInfo, overlaps []*Regi r.subRegions[region.GetID()] = item // It has been removed and all information needs to be updated again. // Set peers then. - setPeer := func(peersMap map[uint64]*regionTree, storeID uint64, item *regionItem) { + setPeer := func(peersMap map[uint64]*regionTree, storeID uint64, item *regionItem, countRef bool) { store, ok := peersMap[storeID] if !ok { - store = newRegionTree() + if !countRef { + store = newRegionTree() + } else { + store = newRegionTreeWithCountRef() + } peersMap[storeID] = store } store.update(item, false) @@ -1298,17 +1323,17 @@ func (r *RegionsInfo) UpdateSubTree(region, origin *RegionInfo, overlaps []*Regi storeID := peer.GetStoreId() if peer.GetId() == region.leader.GetId() { // Add leader peer to leaders. - setPeer(r.leaders, storeID, item) + setPeer(r.leaders, storeID, item, true) } else { // Add follower peer to followers. - setPeer(r.followers, storeID, item) + setPeer(r.followers, storeID, item, false) } } setPeers := func(peersMap map[uint64]*regionTree, peers []*metapb.Peer) { for _, peer := range peers { storeID := peer.GetStoreId() - setPeer(peersMap, storeID, item) + setPeer(peersMap, storeID, item, false) } } // Add to learners. @@ -1516,6 +1541,60 @@ func (r *RegionsInfo) GetStoreRegions(storeID uint64) []*RegionInfo { return regions } +// SubTreeRegionType is the type of sub tree region. +type SubTreeRegionType string + +const ( + // AllInSubTree is all sub trees. + AllInSubTree SubTreeRegionType = "all" + // LeaderInSubTree is the leader sub tree. + LeaderInSubTree SubTreeRegionType = "leader" + // FollowerInSubTree is the follower sub tree. + FollowerInSubTree SubTreeRegionType = "follower" + // LearnerInSubTree is the learner sub tree. + LearnerInSubTree SubTreeRegionType = "learner" + // WitnessInSubTree is the witness sub tree. + WitnessInSubTree SubTreeRegionType = "witness" + // PendingPeerInSubTree is the pending peer sub tree. 
+ PendingPeerInSubTree SubTreeRegionType = "pending" +) + +// GetStoreRegions gets all RegionInfo with a given storeID +func (r *RegionsInfo) GetStoreRegionsByTypeInSubTree(storeID uint64, typ SubTreeRegionType) ([]*RegionInfo, error) { + r.st.RLock() + var regions []*RegionInfo + switch typ { + case LeaderInSubTree: + if leaders, ok := r.leaders[storeID]; ok { + regions = leaders.scanRanges() + } + case FollowerInSubTree: + if followers, ok := r.followers[storeID]; ok { + regions = followers.scanRanges() + } + case LearnerInSubTree: + if learners, ok := r.learners[storeID]; ok { + regions = learners.scanRanges() + } + case WitnessInSubTree: + if witnesses, ok := r.witnesses[storeID]; ok { + regions = witnesses.scanRanges() + } + case PendingPeerInSubTree: + if pendingPeers, ok := r.pendingPeers[storeID]; ok { + regions = pendingPeers.scanRanges() + } + case AllInSubTree: + r.st.RUnlock() + return r.GetStoreRegions(storeID), nil + default: + return nil, errors.Errorf("unknown sub tree region type %v", typ) + } + + r.st.RUnlock() + return regions, nil +} + // GetStoreLeaderRegionSize get total size of store's leader regions func (r *RegionsInfo) GetStoreLeaderRegionSize(storeID uint64) int64 { r.st.RLock() diff --git a/pkg/core/region_test.go b/pkg/core/region_test.go index 88683968f3fd..43629fccda0e 100644 --- a/pkg/core/region_test.go +++ b/pkg/core/region_test.go @@ -874,9 +874,10 @@ func TestUpdateRegionEquivalence(t *testing.T) { ctx := ContextTODO() regionsOld.AtomicCheckAndPutRegion(ctx, item) // new way + newItem := item.Clone() ctx = ContextTODO() - regionsNew.CheckAndPutRootTree(ctx, item) - regionsNew.CheckAndPutSubTree(item) + regionsNew.CheckAndPutRootTree(ctx, newItem) + regionsNew.CheckAndPutSubTree(newItem) } checksEquivalence := func() { re.Equal(regionsOld.GetRegionCount([]byte(""), []byte("")), regionsNew.GetRegionCount([]byte(""), []byte(""))) @@ -884,6 +885,13 @@ func TestUpdateRegionEquivalence(t *testing.T) { checkRegions(re, regionsOld) checkRegions(re, regionsNew) + for _, r := range regionsOld.GetRegions() { + re.Equal(int32(2), r.GetRef(), fmt.Sprintf("inconsistent region %d", r.GetID())) + } + for _, r := range regionsNew.GetRegions() { + re.Equal(int32(2), r.GetRef(), fmt.Sprintf("inconsistent region %d", r.GetID())) + } + for i := 1; i <= storeNums; i++ { re.Equal(regionsOld.GetStoreRegionCount(uint64(i)), regionsNew.GetStoreRegionCount(uint64(i))) re.Equal(regionsOld.GetStoreLeaderCount(uint64(i)), regionsNew.GetStoreLeaderCount(uint64(i))) @@ -938,3 +946,78 @@ func generateTestRegions(count int, storeNum int) []*RegionInfo { } return items } + +func TestUpdateRegionEventualConsistency(t *testing.T) { + re := require.New(t) + regionsOld := NewRegionsInfo() + regionsNew := NewRegionsInfo() + i := 1 + storeNum := 5 + peer1 := &metapb.Peer{StoreId: uint64(i%storeNum + 1), Id: uint64(i*storeNum + 1)} + peer2 := &metapb.Peer{StoreId: uint64((i+1)%storeNum + 1), Id: uint64(i*storeNum + 2)} + peer3 := &metapb.Peer{StoreId: uint64((i+2)%storeNum + 1), Id: uint64(i*storeNum + 3)} + item := NewRegionInfo(&metapb.Region{ + Id: uint64(i + 1), + Peers: []*metapb.Peer{peer1, peer2, peer3}, + StartKey: []byte(fmt.Sprintf("%20d", i*10)), + EndKey: []byte(fmt.Sprintf("%20d", (i+1)*10)), + RegionEpoch: &metapb.RegionEpoch{ConfVer: 100, Version: 100}, + }, + peer1, + SetApproximateKeys(10), + SetApproximateSize(10), + ) + regionItemA := item + regionPendingItemA := regionItemA.Clone(WithPendingPeers([]*metapb.Peer{peer3})) + + regionItemB := regionItemA.Clone() + 
+	regionPendingItemB := regionItemB.Clone(WithPendingPeers([]*metapb.Peer{peer3}))
+	regionGuide := GenerateRegionGuideFunc(true)
+
+	// Old way
+	{
+		ctx := ContextTODO()
+		regionsOld.AtomicCheckAndPutRegion(ctx, regionPendingItemA)
+		re.Equal(int32(2), regionPendingItemA.GetRef())
+		// check new item
+		saveKV, saveCache, needSync := regionGuide(ctx, regionItemA, regionPendingItemA)
+		re.True(needSync)
+		re.True(saveCache)
+		re.False(saveKV)
+		// update cache
+		regionsOld.AtomicCheckAndPutRegion(ctx, regionItemA)
+		re.Equal(int32(2), regionItemA.GetRef())
+	}
+
+	// New way
+	{
+		// The root tree is updated in order: regionPendingItemB first, then regionItemB.
+		ctx := ContextTODO()
+		regionsNew.CheckAndPutRootTree(ctx, regionPendingItemB)
+		re.Equal(int32(1), regionPendingItemB.GetRef())
+		ctx = ContextTODO()
+		regionsNew.CheckAndPutRootTree(ctx, regionItemB)
+		re.Equal(int32(1), regionItemB.GetRef())
+		re.Equal(int32(0), regionPendingItemB.GetRef())
+
+		// The sub tree misses the order and is updated out of order: regionItemB first, then regionPendingItemB.
+		regionsNew.CheckAndPutSubTree(regionItemB)
+		re.Equal(int32(2), regionItemB.GetRef())
+		re.Equal(int32(0), regionPendingItemB.GetRef())
+		regionsNew.UpdateSubTreeOrderInsensitive(regionPendingItemB)
+		re.Equal(int32(1), regionItemB.GetRef())
+		re.Equal(int32(1), regionPendingItemB.GetRef())
+
+		// Heartbeat again: there is no need to update the root tree.
+		saveKV, saveCache, needSync := regionGuide(ctx, regionItemB, regionItemB)
+		re.False(needSync)
+		re.False(saveCache)
+		re.False(saveKV)
+
+		// However, the sub tree still needs to be updated.
+		item := regionsNew.GetRegion(regionItemB.GetID())
+		re.Equal(int32(1), item.GetRef())
+		regionsNew.CheckAndPutSubTree(item)
+		re.Equal(int32(2), item.GetRef())
+	}
+}
diff --git a/pkg/core/region_tree.go b/pkg/core/region_tree.go
index 8c928f391eb3..6c3c71c51588 100644
--- a/pkg/core/region_tree.go
+++ b/pkg/core/region_tree.go
@@ -69,6 +69,8 @@ type regionTree struct {
 	totalWriteKeysRate float64
 	// count the number of regions that not loaded from storage.
 	notFromStorageRegionsCnt int
+	// whether to count the reference of RegionInfo
+	countRef bool
 }
 
 func newRegionTree() *regionTree {
@@ -81,6 +83,17 @@ func newRegionTree() *regionTree {
 	}
 }
 
+func newRegionTreeWithCountRef() *regionTree {
+	return &regionTree{
+		tree:                     btree.NewG[*regionItem](defaultBTreeDegree),
+		totalSize:                0,
+		totalWriteBytesRate:      0,
+		totalWriteKeysRate:       0,
+		notFromStorageRegionsCnt: 0,
+		countRef:                 true,
+	}
+}
+
 func (t *regionTree) length() int {
 	if t == nil {
 		return 0
@@ -140,6 +153,9 @@ func (t *regionTree) update(item *regionItem, withOverlaps bool, overlaps ...*re
 		t.tree.Delete(old)
 	}
 	t.tree.ReplaceOrInsert(item)
+	if t.countRef {
+		item.RegionInfo.IncRef()
+	}
 	result := make([]*RegionInfo, len(overlaps))
 	for i, overlap := range overlaps {
 		old := overlap.RegionInfo
@@ -155,6 +171,9 @@ func (t *regionTree) update(item *regionItem, withOverlaps bool, overlaps ...*re
 		if !old.LoadedFromStorage() {
 			t.notFromStorageRegionsCnt--
 		}
+		if t.countRef {
+			old.DecRef()
+		}
 	}
 
 	return result
@@ -180,6 +199,10 @@ func (t *regionTree) updateStat(origin *RegionInfo, region *RegionInfo) {
 	if !origin.LoadedFromStorage() && region.LoadedFromStorage() {
 		t.notFromStorageRegionsCnt--
 	}
+	if t.countRef {
+		origin.DecRef()
+		region.IncRef()
+	}
 }
 
 // remove removes a region if the region is in the tree.
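A note on the reference counting introduced in region_tree.go above, as far as it can be inferred from this change: the root tree and the leader sub tree are the trees that count references, so a RegionInfo that has reached both ends up with a reference count of 2, while a count below 2 means the sub tree still holds an older version of the region. The heartbeat paths later rely on exactly this condition (origin.GetRef() < 2) to reschedule CheckAndPutSubTree. The standalone sketch below only illustrates that invariant; fakeRegion and its methods are hypothetical stand-ins, not the real RegionInfo API.

package main

import (
	"fmt"
	"sync/atomic"
)

// fakeRegion is a hypothetical stand-in for the reference counter carried by RegionInfo.
type fakeRegion struct {
	ref atomic.Int32
}

func (r *fakeRegion) IncRef()       { r.ref.Add(1) }
func (r *fakeRegion) DecRef()       { r.ref.Add(-1) }
func (r *fakeRegion) GetRef() int32 { return r.ref.Load() }

func main() {
	r := &fakeRegion{}
	// The root tree takes the first reference when the region is inserted.
	r.IncRef()
	fmt.Println("sub tree update pending:", r.GetRef() < 2) // true

	// The count-ref sub tree takes the second reference once it catches up.
	r.IncRef()
	fmt.Println("sub tree update pending:", r.GetRef() < 2) // false

	// Removing the region from the sub tree drops the count back below 2.
	r.DecRef()
	fmt.Println("sub tree update pending:", r.GetRef() < 2) // true
}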
@@ -199,6 +222,9 @@ func (t *regionTree) remove(region *RegionInfo) {
 	regionWriteBytesRate, regionWriteKeysRate := result.GetWriteRate()
 	t.totalWriteBytesRate -= regionWriteBytesRate
 	t.totalWriteKeysRate -= regionWriteKeysRate
+	if t.countRef {
+		result.RegionInfo.DecRef()
+	}
 	if !region.LoadedFromStorage() {
 		t.notFromStorageRegionsCnt--
 	}
diff --git a/pkg/core/region_tree_test.go b/pkg/core/region_tree_test.go
index f4ef6cb67b3b..3f2ca0c1fb8f 100644
--- a/pkg/core/region_tree_test.go
+++ b/pkg/core/region_tree_test.go
@@ -159,6 +159,7 @@ func TestRegionTree(t *testing.T) {
 	updateNewItem(tree, regionA)
 	updateNewItem(tree, regionC)
 	re.Nil(tree.overlaps(newRegionItem([]byte("b"), []byte("c"))))
+	re.Equal(regionC, tree.overlaps(newRegionItem([]byte("c"), []byte("d")))[0].RegionInfo)
 	re.Equal(regionC, tree.overlaps(newRegionItem([]byte("a"), []byte("cc")))[1].RegionInfo)
 	re.Nil(tree.search([]byte{}))
 	re.Equal(regionA, tree.search([]byte("a")))
diff --git a/pkg/mcs/scheduling/server/cluster.go b/pkg/mcs/scheduling/server/cluster.go
index ba3323f4fb10..74612dbe8ded 100644
--- a/pkg/mcs/scheduling/server/cluster.go
+++ b/pkg/mcs/scheduling/server/cluster.go
@@ -619,6 +619,16 @@ func (c *Cluster) processRegionHeartbeat(ctx *core.MetaProcessContext, region *c
 				ratelimit.WithTaskName(ratelimit.ObserveRegionStatsAsync),
 			)
 		}
+		// The region has not been updated to the sub tree yet, so schedule the update.
+		if origin.GetRef() < 2 {
+			ctx.TaskRunner.RunTask(
+				ctx,
+				core.ExtraTaskOpts(ctx, core.UpdateSubTree),
+				func(_ context.Context) {
+					c.CheckAndPutSubTree(region)
+				},
+			)
+		}
 		return nil
 	}
 	tracer.OnSaveCacheBegin()
diff --git a/pkg/response/store.go b/pkg/response/store.go
index 1efe11bfb39c..8bff1e75e42d 100644
--- a/pkg/response/store.go
+++ b/pkg/response/store.go
@@ -64,6 +64,7 @@ type StoreStatus struct {
 	RegionSize         int64      `json:"region_size"`
 	LearnerCount       int        `json:"learner_count,omitempty"`
 	WitnessCount       int        `json:"witness_count,omitempty"`
+	PendingPeerCount   int        `json:"pending_peer_count,omitempty"`
 	SlowScore          uint64     `json:"slow_score,omitempty"`
 	SlowTrend          *SlowTrend `json:"slow_trend,omitempty"`
 	SendingSnapCount   uint32     `json:"sending_snap_count,omitempty"`
@@ -117,6 +118,7 @@ func BuildStoreInfo(opt *sc.ScheduleConfig, store *core.StoreInfo) *StoreInfo {
 			SlowTrend:          slowTrend,
 			SendingSnapCount:   store.GetSendingSnapCount(),
 			ReceivingSnapCount: store.GetReceivingSnapCount(),
+			PendingPeerCount:   store.GetPendingPeerCount(),
 			IsBusy:             store.IsBusy(),
 		},
 	}
diff --git a/pkg/statistics/region_collection.go b/pkg/statistics/region_collection.go
index cb0de6f601b8..565597b4efb3 100644
--- a/pkg/statistics/region_collection.go
+++ b/pkg/statistics/region_collection.go
@@ -158,14 +158,14 @@ func (r *RegionStatistics) RegionStatsNeedUpdate(region *core.RegionInfo) bool {
 		region.IsOversized(int64(r.conf.GetRegionMaxSize()), int64(r.conf.GetRegionMaxKeys())) {
 		return true
 	}
-	// expected to be zero for below type
-	if r.IsRegionStatsType(regionID, PendingPeer) && len(region.GetPendingPeers()) == 0 {
+
+	if r.IsRegionStatsType(regionID, PendingPeer) != (len(region.GetPendingPeers()) != 0) {
 		return true
 	}
-	if r.IsRegionStatsType(regionID, DownPeer) && len(region.GetDownPeers()) == 0 {
+	if r.IsRegionStatsType(regionID, DownPeer) != (len(region.GetDownPeers()) != 0) {
 		return true
 	}
-	if r.IsRegionStatsType(regionID, LearnerPeer) && len(region.GetLearners()) == 0 {
+	if r.IsRegionStatsType(regionID, LearnerPeer) != (len(region.GetLearners()) != 0) {
 		return true
 	}
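The rewritten checks in RegionStatsNeedUpdate above amount to a mismatch test between the cached statistics flag and the region's current state: the old form only fired when a recorded flag no longer applied, while the new form also fires when the state appears before the flag has been recorded. A minimal sketch of the difference, using a hypothetical pair of helpers rather than the real RegionStatistics API:

package main

import "fmt"

// oldNeedUpdate mirrors the removed check: only "flag set, state gone" triggers an update.
func oldNeedUpdate(flagged, hasPendingPeers bool) bool {
	return flagged && !hasPendingPeers
}

// newNeedUpdate mirrors the new check: any disagreement between the cached flag
// and the current state (a logical XOR) triggers an update.
func newNeedUpdate(flagged, hasPendingPeers bool) bool {
	return flagged != hasPendingPeers
}

func main() {
	fmt.Println(oldNeedUpdate(false, true), newNeedUpdate(false, true)) // false true: peers appeared, flag missing
	fmt.Println(oldNeedUpdate(true, false), newNeedUpdate(true, false)) // true true: flag recorded, peers gone
	fmt.Println(oldNeedUpdate(true, true), newNeedUpdate(true, true))   // false false: still consistent
}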
diff --git a/server/api/region.go b/server/api/region.go
index dac92f247ca9..974b5e4fa120 100644
--- a/server/api/region.go
+++ b/server/api/region.go
@@ -218,7 +218,17 @@ func (h *regionsHandler) GetStoreRegions(w http.ResponseWriter, r *http.Request)
 		h.rd.JSON(w, http.StatusBadRequest, err.Error())
 		return
 	}
-	regions := rc.GetStoreRegions(uint64(id))
+	// Get the region type from the query parameters.
+	typ := r.URL.Query().Get("type")
+	if len(typ) == 0 {
+		typ = string(core.AllInSubTree)
+	}
+
+	regions, err := rc.GetStoreRegionsByTypeInSubTree(uint64(id), core.SubTreeRegionType(typ))
+	if err != nil {
+		h.rd.JSON(w, http.StatusBadRequest, err.Error())
+		return
+	}
 	b, err := response.MarshalRegionsInfoJSON(r.Context(), regions)
 	if err != nil {
 		h.rd.JSON(w, http.StatusInternalServerError, err.Error())
diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go
index f6dae46972f8..050919c97a3b 100644
--- a/server/cluster/cluster.go
+++ b/server/cluster/cluster.go
@@ -1047,6 +1047,16 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio
 				ratelimit.WithTaskName(ratelimit.ObserveRegionStatsAsync),
 			)
 		}
+		// The region has not been updated to the sub tree yet, so schedule the update.
+		if origin.GetRef() < 2 {
+			ctx.TaskRunner.RunTask(
+				ctx,
+				core.ExtraTaskOpts(ctx, core.UpdateSubTree),
+				func(_ context.Context) {
+					c.CheckAndPutSubTree(region)
+				},
+			)
+		}
 		return nil
 	}
 	failpoint.Inject("concurrentRegionHeartbeat", func() {
@@ -1201,6 +1211,11 @@ func (c *RaftCluster) GetStoreRegions(storeID uint64) []*core.RegionInfo {
 	return c.core.GetStoreRegions(storeID)
 }
 
+// GetStoreRegionsByType returns all regions' information with a given storeID.
+func (c *RaftCluster) GetStoreRegionsByType(storeID uint64) []*core.RegionInfo {
+	return c.core.GetStoreRegions(storeID)
+}
+
 // RandLeaderRegions returns some random regions that has leader on the store.
 func (c *RaftCluster) RandLeaderRegions(storeID uint64, ranges []core.KeyRange) []*core.RegionInfo {
 	return c.core.RandLeaderRegions(storeID, ranges)
diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go
index 10be418c0294..dfe7a6980c78 100644
--- a/tests/integrations/client/client_test.go
+++ b/tests/integrations/client/client_test.go
@@ -26,6 +26,7 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -248,6 +249,40 @@ func TestLeaderTransferAndMoveCluster(t *testing.T) {
 	wg.Wait()
 }
 
+func TestGetTSAfterTransferLeader(t *testing.T) {
+	re := require.New(t)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	cluster, err := tests.NewTestCluster(ctx, 2)
+	re.NoError(err)
+	endpoints := runServer(re, cluster)
+	leader := cluster.WaitLeader()
+	re.NotEmpty(leader)
+	defer cluster.Destroy()
+
+	cli := setupCli(ctx, re, endpoints, pd.WithCustomTimeoutOption(10*time.Second))
+	defer cli.Close()
+
+	var leaderSwitched atomic.Bool
+	cli.GetServiceDiscovery().AddServingURLSwitchedCallback(func() {
+		leaderSwitched.Store(true)
+	})
+	err = cluster.GetServer(leader).ResignLeader()
+	re.NoError(err)
+	newLeader := cluster.WaitLeader()
+	re.NotEmpty(newLeader)
+	re.NotEqual(leader, newLeader)
+	leader = cluster.WaitLeader()
+	re.NotEmpty(leader)
+	err = cli.GetServiceDiscovery().CheckMemberChanged()
+	re.NoError(err)
+
+	testutil.Eventually(re, leaderSwitched.Load)
+	// The leader stream must be updated after the leader switch is sensed by the client.
+	_, _, err = cli.GetTS(context.TODO())
+	re.NoError(err)
+}
+
 func TestTSOAllocatorLeader(t *testing.T) {
 	re := require.New(t)
 	ctx, cancel := context.WithCancel(context.Background())
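The server/api change above lets the store-regions endpoint filter by a type query parameter, and the pd-ctl change below simply forwards its --type flag as that parameter. A rough sketch of calling the endpoint directly over HTTP; the PD address and the /pd/api/v1/regions/store route prefix are assumptions for illustration and should be adjusted to the actual deployment:

package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Assumed PD address and route; an unknown type value is rejected with 400 Bad Request.
	url := "http://127.0.0.1:2379/pd/api/v1/regions/store/1?type=pending"
	resp, err := http.Get(url)
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	// The payload has the same shape as the unfiltered endpoint, restricted to
	// the regions held by the pending-peer sub tree of store 1.
	fmt.Println(resp.Status, len(body), "bytes")
}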
diff --git a/tools/pd-ctl/pdctl/command/region_command.go b/tools/pd-ctl/pdctl/command/region_command.go
index d7e19967c7a0..3536b01a606a 100644
--- a/tools/pd-ctl/pdctl/command/region_command.go
+++ b/tools/pd-ctl/pdctl/command/region_command.go
@@ -486,6 +486,7 @@ func NewRegionWithStoreCommand() *cobra.Command {
 		Short: "show the regions of a specific store",
 		Run:   showRegionWithStoreCommandFunc,
 	}
+	r.Flags().String("type", "all", "the type of the regions, could be 'all', 'leader', 'follower', 'learner', 'witness' or 'pending'")
 	return r
 }
 
@@ -496,6 +497,8 @@ func showRegionWithStoreCommandFunc(cmd *cobra.Command, args []string) {
 	}
 	storeID := args[0]
 	prefix := regionsStorePrefix + "/" + storeID
+	flagType := cmd.Flag("type")
+	prefix += "?type=" + flagType.Value.String()
 	r, err := doRequest(cmd, prefix, http.MethodGet, http.Header{})
 	if err != nil {
 		cmd.Printf("Failed to get regions with the given storeID: %s\n", err)
diff --git a/tools/pd-ctl/tests/region/region_test.go b/tools/pd-ctl/tests/region/region_test.go
index b328fd882864..2952e137f3b4 100644
--- a/tools/pd-ctl/tests/region/region_test.go
+++ b/tools/pd-ctl/tests/region/region_test.go
@@ -108,6 +108,11 @@ func TestRegion(t *testing.T) {
 	)
 	defer cluster.Destroy()
 
+	getRegionsByType := func(storeID uint64, regionType core.SubTreeRegionType) []*core.RegionInfo {
+		regions, _ := leaderServer.GetRaftCluster().GetStoreRegionsByTypeInSubTree(storeID, regionType)
+		return regions
+	}
+
 	var testRegionsCases = []struct {
 		args   []string
 		expect []*core.RegionInfo
@@ -118,7 +123,12 @@ func TestRegion(t *testing.T) {
 		{[]string{"region", "sibling", "2"}, leaderServer.GetAdjacentRegions(leaderServer.GetRegionInfoByID(2))},
 		// region store command
 		{[]string{"region", "store", "1"}, leaderServer.GetStoreRegions(1)},
-		{[]string{"region", "store", "1"}, []*core.RegionInfo{r1, r2, r3, r4}},
+		{[]string{"region", "store", "1", "--type=leader"}, getRegionsByType(1, core.LeaderInSubTree)},
+		{[]string{"region", "store", "1", "--type=follower"}, getRegionsByType(1, core.FollowerInSubTree)},
+		{[]string{"region", "store", "1", "--type=learner"}, getRegionsByType(1, core.LearnerInSubTree)},
+		{[]string{"region", "store", "1", "--type=witness"}, getRegionsByType(1, core.WitnessInSubTree)},
+		{[]string{"region", "store", "1", "--type=pending"}, getRegionsByType(1, core.PendingPeerInSubTree)},
+		{[]string{"region", "store", "1", "--type=all"}, []*core.RegionInfo{r1, r2, r3, r4}},
 		// region check extra-peer command
 		{[]string{"region", "check", "extra-peer"}, []*core.RegionInfo{r1}},
 		// region check miss-peer command