diff --git a/metrics/metrics.go b/metrics/metrics.go index 5db99cd33..95fe3aa5c 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -41,64 +41,66 @@ import ( // Client metrics. var ( - TiKVTxnCmdHistogram *prometheus.HistogramVec - TiKVBackoffHistogram *prometheus.HistogramVec - TiKVSendReqHistogram *prometheus.HistogramVec - TiKVSendReqCounter *prometheus.CounterVec - TiKVSendReqTimeCounter *prometheus.CounterVec - TiKVRPCNetLatencyHistogram *prometheus.HistogramVec - TiKVCoprocessorHistogram *prometheus.HistogramVec - TiKVLockResolverCounter *prometheus.CounterVec - TiKVRegionErrorCounter *prometheus.CounterVec - TiKVTxnWriteKVCountHistogram prometheus.Histogram - TiKVTxnWriteSizeHistogram prometheus.Histogram - TiKVRawkvCmdHistogram *prometheus.HistogramVec - TiKVRawkvSizeHistogram *prometheus.HistogramVec - TiKVTxnRegionsNumHistogram *prometheus.HistogramVec - TiKVLoadSafepointCounter *prometheus.CounterVec - TiKVSecondaryLockCleanupFailureCounter *prometheus.CounterVec - TiKVRegionCacheCounter *prometheus.CounterVec - TiKVLoadRegionCacheHistogram *prometheus.HistogramVec - TiKVLocalLatchWaitTimeHistogram prometheus.Histogram - TiKVStatusDuration *prometheus.HistogramVec - TiKVStatusCounter *prometheus.CounterVec - TiKVBatchWaitDuration prometheus.Histogram - TiKVBatchSendLatency prometheus.Histogram - TiKVBatchWaitOverLoad prometheus.Counter - TiKVBatchPendingRequests *prometheus.HistogramVec - TiKVBatchRequests *prometheus.HistogramVec - TiKVBatchClientUnavailable prometheus.Histogram - TiKVBatchClientWaitEstablish prometheus.Histogram - TiKVBatchClientRecycle prometheus.Histogram - TiKVBatchRecvLatency *prometheus.HistogramVec - TiKVRangeTaskStats *prometheus.GaugeVec - TiKVRangeTaskPushDuration *prometheus.HistogramVec - TiKVTokenWaitDuration prometheus.Histogram - TiKVTxnHeartBeatHistogram *prometheus.HistogramVec - TiKVPessimisticLockKeysDuration prometheus.Histogram - TiKVTTLLifeTimeReachCounter prometheus.Counter - TiKVNoAvailableConnectionCounter prometheus.Counter - TiKVTwoPCTxnCounter *prometheus.CounterVec - TiKVAsyncCommitTxnCounter *prometheus.CounterVec - TiKVOnePCTxnCounter *prometheus.CounterVec - TiKVStoreLimitErrorCounter *prometheus.CounterVec - TiKVGRPCConnTransientFailureCounter *prometheus.CounterVec - TiKVPanicCounter *prometheus.CounterVec - TiKVForwardRequestCounter *prometheus.CounterVec - TiKVTSFutureWaitDuration prometheus.Histogram - TiKVSafeTSUpdateCounter *prometheus.CounterVec - TiKVMinSafeTSGapSeconds *prometheus.GaugeVec - TiKVReplicaSelectorFailureCounter *prometheus.CounterVec - TiKVRequestRetryTimesHistogram prometheus.Histogram - TiKVTxnCommitBackoffSeconds prometheus.Histogram - TiKVTxnCommitBackoffCount prometheus.Histogram - TiKVSmallReadDuration prometheus.Histogram - TiKVReadThroughput prometheus.Histogram - TiKVUnsafeDestroyRangeFailuresCounterVec *prometheus.CounterVec - TiKVPrewriteAssertionUsageCounter *prometheus.CounterVec - TiKVStaleReadCounter *prometheus.CounterVec - TiKVStaleReadReqCounter *prometheus.CounterVec - TiKVStaleReadBytes *prometheus.CounterVec + TiKVTxnCmdHistogram *prometheus.HistogramVec + TiKVBackoffHistogram *prometheus.HistogramVec + TiKVSendReqHistogram *prometheus.HistogramVec + TiKVSendReqCounter *prometheus.CounterVec + TiKVSendReqTimeCounter *prometheus.CounterVec + TiKVRPCNetLatencyHistogram *prometheus.HistogramVec + TiKVCoprocessorHistogram *prometheus.HistogramVec + TiKVLockResolverCounter *prometheus.CounterVec + TiKVRegionErrorCounter *prometheus.CounterVec + TiKVTxnWriteKVCountHistogram prometheus.Histogram + TiKVTxnWriteSizeHistogram prometheus.Histogram + TiKVRawkvCmdHistogram *prometheus.HistogramVec + TiKVRawkvSizeHistogram *prometheus.HistogramVec + TiKVTxnRegionsNumHistogram *prometheus.HistogramVec + TiKVLoadSafepointCounter *prometheus.CounterVec + TiKVSecondaryLockCleanupFailureCounter *prometheus.CounterVec + TiKVRegionCacheCounter *prometheus.CounterVec + TiKVLoadRegionCacheHistogram *prometheus.HistogramVec + TiKVLocalLatchWaitTimeHistogram prometheus.Histogram + TiKVStatusDuration *prometheus.HistogramVec + TiKVStatusCounter *prometheus.CounterVec + TiKVBatchWaitDuration prometheus.Histogram + TiKVBatchSendLatency prometheus.Histogram + TiKVBatchWaitOverLoad prometheus.Counter + TiKVBatchPendingRequests *prometheus.HistogramVec + TiKVBatchRequests *prometheus.HistogramVec + TiKVBatchClientUnavailable prometheus.Histogram + TiKVBatchClientWaitEstablish prometheus.Histogram + TiKVBatchClientRecycle prometheus.Histogram + TiKVBatchRecvLatency *prometheus.HistogramVec + TiKVRangeTaskStats *prometheus.GaugeVec + TiKVRangeTaskPushDuration *prometheus.HistogramVec + TiKVTokenWaitDuration prometheus.Histogram + TiKVTxnHeartBeatHistogram *prometheus.HistogramVec + TiKVPessimisticLockKeysDuration prometheus.Histogram + TiKVTTLLifeTimeReachCounter prometheus.Counter + TiKVNoAvailableConnectionCounter prometheus.Counter + TiKVTwoPCTxnCounter *prometheus.CounterVec + TiKVAsyncCommitTxnCounter *prometheus.CounterVec + TiKVOnePCTxnCounter *prometheus.CounterVec + TiKVStoreLimitErrorCounter *prometheus.CounterVec + TiKVGRPCConnTransientFailureCounter *prometheus.CounterVec + TiKVPanicCounter *prometheus.CounterVec + TiKVForwardRequestCounter *prometheus.CounterVec + TiKVTSFutureWaitDuration prometheus.Histogram + TiKVSafeTSUpdateCounter *prometheus.CounterVec + TiKVMinSafeTSGapSeconds *prometheus.GaugeVec + TiKVReplicaSelectorFailureCounter *prometheus.CounterVec + TiKVRequestRetryTimesHistogram prometheus.Histogram + TiKVTxnCommitBackoffSeconds prometheus.Histogram + TiKVTxnCommitBackoffCount prometheus.Histogram + TiKVSmallReadDuration prometheus.Histogram + TiKVReadThroughput prometheus.Histogram + TiKVUnsafeDestroyRangeFailuresCounterVec *prometheus.CounterVec + TiKVPrewriteAssertionUsageCounter *prometheus.CounterVec + TiKVStaleReadCounter *prometheus.CounterVec + TiKVStaleReadReqCounter *prometheus.CounterVec + TiKVStaleReadBytes *prometheus.CounterVec + TiKVValidateReadTSFromPDCount prometheus.Counter + TiKVLowResolutionTSOUpdateIntervalSecondsGauge prometheus.Gauge ) // Label constants. @@ -617,6 +619,22 @@ func initMetrics(namespace, subsystem string) { Help: "Counter of stale read requests bytes", }, []string{LblResult, LblDirection}) + TiKVValidateReadTSFromPDCount = prometheus.NewCounter( + prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "validate_read_ts_from_pd_count", + Help: "Counter of validating read ts by getting a timestamp from PD", + }) + + TiKVLowResolutionTSOUpdateIntervalSecondsGauge = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "low_resolution_tso_update_interval_seconds", + Help: "The actual working update interval for the low resolution TSO. As there are adaptive mechanism internally, this value may differ from the config.", + }) + initShortcuts() } @@ -690,6 +708,8 @@ func RegisterMetrics() { prometheus.MustRegister(TiKVStaleReadCounter) prometheus.MustRegister(TiKVStaleReadReqCounter) prometheus.MustRegister(TiKVStaleReadBytes) + prometheus.MustRegister(TiKVValidateReadTSFromPDCount) + prometheus.MustRegister(TiKVLowResolutionTSOUpdateIntervalSecondsGauge) } // readCounter reads the value of a prometheus.Counter. diff --git a/oracle/oracle.go b/oracle/oracle.go index b90a950e4..7ace335ec 100644 --- a/oracle/oracle.go +++ b/oracle/oracle.go @@ -50,6 +50,12 @@ type Oracle interface { GetTimestampAsync(ctx context.Context, opt *Option) Future GetLowResolutionTimestamp(ctx context.Context, opt *Option) (uint64, error) GetLowResolutionTimestampAsync(ctx context.Context, opt *Option) Future + // GetStaleTimestamp generates a timestamp based on the recently fetched timestamp and the elapsed time since + // when that timestamp was fetched. The result is expected to be about `prevSecond` seconds before the current + // time. + // WARNING: This method does not guarantee whether the generated timestamp is legal for accessing the data. + // Neither is it safe to use it for verifying the legality of another calculated timestamp. + // Be sure to validate the timestamp before using it to access the data. GetStaleTimestamp(ctx context.Context, txnScope string, prevSecond uint64) (uint64, error) IsExpired(lockTimestamp, TTL uint64, opt *Option) bool UntilExpired(lockTimeStamp, TTL uint64, opt *Option) int64 @@ -57,6 +63,13 @@ type Oracle interface { GetExternalTimestamp(ctx context.Context) (uint64, error) SetExternalTimestamp(ctx context.Context, ts uint64) error + + // ValidateSnapshotReadTS verifies whether it can be guaranteed that the given readTS doesn't exceed the maximum ts + // that has been allocated by the oracle, so that it's safe to use this ts to perform snapshot read, stale read, + // etc. + // Note that this method only checks the ts from the oracle's perspective. It doesn't check whether the snapshot + // has been GCed. + ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *Option) error } // Future is a future which promises to return a timestamp. diff --git a/oracle/oracles/local.go b/oracle/oracles/local.go index 336affb27..1e6b747c9 100644 --- a/oracle/oracles/local.go +++ b/oracle/oracles/local.go @@ -39,6 +39,7 @@ import ( "sync" "time" + "github.com/pingcap/errors" "github.com/tikv/client-go/v2/oracle" ) @@ -134,3 +135,14 @@ func (l *localOracle) SetExternalTimestamp(ctx context.Context, newTimestamp uin func (l *localOracle) GetExternalTimestamp(ctx context.Context) (uint64, error) { return l.getExternalTimestamp(ctx) } + +func (l *localOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error { + currentTS, err := l.GetTimestamp(ctx, opt) + if err != nil { + return errors.Errorf("fail to validate read timestamp: %v", err) + } + if currentTS < readTS { + return errors.Errorf("cannot set read timestamp to a future time") + } + return nil +} diff --git a/oracle/oracles/mock.go b/oracle/oracles/mock.go index 5cf7867e6..183b4c2d6 100644 --- a/oracle/oracles/mock.go +++ b/oracle/oracles/mock.go @@ -122,6 +122,17 @@ func (o *MockOracle) GetLowResolutionTimestampAsync(ctx context.Context, opt *or return o.GetTimestampAsync(ctx, opt) } +func (o *MockOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error { + currentTS, err := o.GetTimestamp(ctx, opt) + if err != nil { + return errors.Errorf("fail to validate read timestamp: %v", err) + } + if currentTS < readTS { + return errors.Errorf("cannot set read timestamp to a future time") + } + return nil +} + // IsExpired implements oracle.Oracle interface. func (o *MockOracle) IsExpired(lockTimestamp, TTL uint64, _ *oracle.Option) bool { o.RLock() diff --git a/oracle/oracles/pd.go b/oracle/oracles/pd.go index 123cd58aa..29d537c8d 100644 --- a/oracle/oracles/pd.go +++ b/oracle/oracles/pd.go @@ -36,6 +36,7 @@ package oracles import ( "context" + "fmt" "strings" "sync" "sync/atomic" @@ -48,18 +49,109 @@ import ( "github.com/tikv/client-go/v2/oracle" pd "github.com/tikv/pd/client" "go.uber.org/zap" + "golang.org/x/sync/singleflight" ) var _ oracle.Oracle = &pdOracle{} const slowDist = 30 * time.Millisecond +type adaptiveUpdateTSIntervalState int + +const ( + adaptiveUpdateTSIntervalStateNone adaptiveUpdateTSIntervalState = iota + // adaptiveUpdateTSIntervalStateNormal represents the state that the adaptive update ts interval is synced with the + // configuration without performing any automatic adjustment. + adaptiveUpdateTSIntervalStateNormal + // adaptiveUpdateTSIntervalStateAdapting represents the state that as there are recently some stale read / snapshot + // read operations requesting a short staleness (now - readTS is nearly or exceeds the current update interval), + // so that we automatically shrink the update interval. Otherwise, read operations may don't have low resolution ts + // that is new enough for checking the legality of the read ts, causing them have to fetch the latest ts from PD, + // which is time-consuming. + adaptiveUpdateTSIntervalStateAdapting + // adaptiveUpdateTSIntervalStateRecovering represents the state that the update ts interval have once been shrunk, + // to adapt to reads with short staleness, but there isn't any such read operations for a while, so that we + // gradually recover the update interval to the configured value. + adaptiveUpdateTSIntervalStateRecovering + // adaptiveUpdateTSIntervalStateUnadjustable represents the state that the user has configured a very short update + // interval, so that we don't have any space to automatically adjust it. + adaptiveUpdateTSIntervalStateUnadjustable +) + +func (s adaptiveUpdateTSIntervalState) String() string { + switch s { + case adaptiveUpdateTSIntervalStateNormal: + return "normal" + case adaptiveUpdateTSIntervalStateAdapting: + return "adapting" + case adaptiveUpdateTSIntervalStateRecovering: + return "recovering" + case adaptiveUpdateTSIntervalStateUnadjustable: + return "unadjustable" + default: + return fmt.Sprintf("unknown(%v)", int(s)) + } +} + +const ( + // minAllowedAdaptiveUpdateTSInterval is the lower bound of the adaptive update ts interval for avoiding an abnormal + // read operation causing the update interval to be too short. + minAllowedAdaptiveUpdateTSInterval = 500 * time.Millisecond + // adaptiveUpdateTSIntervalShrinkingPreserve is the duration that we additionally shrinks when adapting to a read + // operation that requires a short staleness. + adaptiveUpdateTSIntervalShrinkingPreserve = 100 * time.Millisecond + // adaptiveUpdateTSIntervalBlockRecoverThreshold is the threshold of the difference between the current update + // interval and the staleness the read operation request to prevent the update interval from recovering back to + // normal. + adaptiveUpdateTSIntervalBlockRecoverThreshold = 200 * time.Millisecond + // adaptiveUpdateTSIntervalRecoverPerSecond is the duration that the update interval should grow per second when + // recovering to normal state from adapting state. + adaptiveUpdateTSIntervalRecoverPerSecond = 20 * time.Millisecond + // adaptiveUpdateTSIntervalDelayBeforeRecovering is the duration that we should hold the current adaptive update + // interval before turning back to normal state. + adaptiveUpdateTSIntervalDelayBeforeRecovering = 5 * time.Minute +) + // pdOracle is an Oracle that uses a placement driver client as source. type pdOracle struct { c pd.Client // txn_scope (string) -> lastTSPointer (*lastTSOPointer) lastTSMap sync.Map quit chan struct{} + // The configured interval to update the low resolution ts. Set by SetLowResolutionTimestampUpdateInterval. + // For TiDB, this is directly controlled by the system variable `tidb_low_resolution_tso_update_interval`. + lastTSUpdateInterval atomic.Int64 + // The actual interval to update the low resolution ts. If the configured one is too large to satisfy the + // requirement of the stale read or snapshot read, the actual interval can be automatically set to a shorter + // value than lastTSUpdateInterval. + // This value is also possible to be updated by SetLowResolutionTimestampUpdateInterval, which may happen when + // user adjusting the update interval manually. + adaptiveLastTSUpdateInterval atomic.Int64 + + adaptiveUpdateIntervalState struct { + // The mutex to avoid racing between updateTS goroutine and SetLowResolutionTimestampUpdateInterval. + mu sync.Mutex + // The most recent time that a stale read / snapshot read requests a timestamp that is close enough to + // the current adaptive update interval. If there is such a request recently, the adaptive interval + // should avoid falling back to the original (configured) value. + // Stored in unix microseconds to make it able to be accessed atomically. + lastShortStalenessReadTime atomic.Int64 + // When someone requests need shrinking the update interval immediately, it sends the duration it expects to + // this channel. + shrinkIntervalCh chan time.Duration + + // Only accessed in updateTS goroutine. No need to use atomic value. + lastTick time.Time + // Represents a description about the current state. + state adaptiveUpdateTSIntervalState + } + + // When the low resolution ts is not new enough and there are many concurrent stane read / snapshot read + // operations that needs to validate the read ts, we can use this to avoid too many concurrent GetTS calls by + // reusing a result for different `ValidateSnapshotReadTS` calls. This can be done because that + // we don't require the ts for validation to be strictly the latest one. + // Note that the result can't be reused for different txnScopes. The txnScope is used as the key. + tsForValidation singleflight.Group } // lastTSO stores the last timestamp oracle gets from PD server and the local time when the TSO is fetched. @@ -92,18 +184,37 @@ func (p *lastTSOPointer) compareAndSwap(old, new *lastTSO) bool { return atomic.CompareAndSwapPointer(&p.p, unsafe.Pointer(old), unsafe.Pointer(new)) } +type PDOracleOptions struct { + // The duration to update the last ts, i.e., the low resolution ts. + UpdateInterval time.Duration + // Disable the background periodic update of the last ts. This is for test purposes only. + NoUpdateTS bool +} + // NewPdOracle create an Oracle that uses a pd client source. // Refer https://github.com/tikv/pd/blob/master/client/client.go for more details. // PdOracle mantains `lastTS` to store the last timestamp got from PD server. If // `GetTimestamp()` is not called after `updateInterval`, it will be called by // itself to keep up with the timestamp on PD server. -func NewPdOracle(pdClient pd.Client, updateInterval time.Duration) (oracle.Oracle, error) { +func NewPdOracle(pdClient pd.Client, options *PDOracleOptions) (oracle.Oracle, error) { + if options.UpdateInterval <= 0 { + return nil, fmt.Errorf("updateInterval must be > 0") + } + o := &pdOracle{ - c: pdClient, - quit: make(chan struct{}), + c: pdClient, + quit: make(chan struct{}), + lastTSUpdateInterval: atomic.Int64{}, } + o.adaptiveUpdateIntervalState.shrinkIntervalCh = make(chan time.Duration, 1) + o.lastTSUpdateInterval.Store(int64(options.UpdateInterval)) + o.adaptiveLastTSUpdateInterval.Store(int64(options.UpdateInterval)) + o.adaptiveUpdateIntervalState.lastTick = time.Now() + ctx := context.TODO() - go o.updateTS(ctx, updateInterval) + if !options.NoUpdateTS { + go o.updateTS(ctx) + } // Initialize the timestamp of the global txnScope by Get. _, err := o.GetTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) if err != nil { @@ -238,23 +349,180 @@ func (o *pdOracle) getLastTSWithArrivalTS(txnScope string) (*lastTSO, bool) { return last, true } -func (o *pdOracle) updateTS(ctx context.Context, interval time.Duration) { - ticker := time.NewTicker(interval) +func max(x, y time.Duration) time.Duration { + if x > y { + return x + } + return y +} + +func (o *pdOracle) nextUpdateInterval(now time.Time, requiredStaleness time.Duration) time.Duration { + o.adaptiveUpdateIntervalState.mu.Lock() + defer o.adaptiveUpdateIntervalState.mu.Unlock() + + configuredInterval := time.Duration(o.lastTSUpdateInterval.Load()) + prevAdaptiveUpdateInterval := time.Duration(o.adaptiveLastTSUpdateInterval.Load()) + lastReachDropThresholdTime := time.UnixMilli(o.adaptiveUpdateIntervalState.lastShortStalenessReadTime.Load()) + + currentAdaptiveUpdateInterval := prevAdaptiveUpdateInterval + + // Shortcut + const none = adaptiveUpdateTSIntervalStateNone + + // The following `checkX` functions checks whether it should transit to the X state. Returns + // a tuple representing (state, newInterval). + // When `checkX` returns a valid state, it means that the current situation matches the state. In this case, it + // also returns the new interval that should be used next. + // When it returns `none`, we need to check if it should transit to other states. For each call to + // nextUpdateInterval, if all attempts to `checkX` function returns false, it keeps the previous state unchanged. + + checkUnadjustable := func() (adaptiveUpdateTSIntervalState, time.Duration) { + // If the user has configured a very short interval, we don't have any space to adjust it. Just use + // the user's configured value directly. + if configuredInterval <= minAllowedAdaptiveUpdateTSInterval { + return adaptiveUpdateTSIntervalStateUnadjustable, configuredInterval + } + return none, 0 + } + + checkNormal := func() (adaptiveUpdateTSIntervalState, time.Duration) { + // If the current actual update interval is synced with the configured value, and it's not unadjustable state, + // then it's the normal state. + if configuredInterval > minAllowedAdaptiveUpdateTSInterval && currentAdaptiveUpdateInterval == configuredInterval { + return adaptiveUpdateTSIntervalStateNormal, currentAdaptiveUpdateInterval + } + return none, 0 + } + + checkAdapting := func() (adaptiveUpdateTSIntervalState, time.Duration) { + if requiredStaleness != 0 && requiredStaleness < currentAdaptiveUpdateInterval && currentAdaptiveUpdateInterval > minAllowedAdaptiveUpdateTSInterval { + // If we are calculating the interval because of a request that requires a shorter staleness, we shrink the + // update interval immediately to adapt to it. + // We shrink the update interval to a value slightly lower than the requested staleness to avoid potential + // frequent shrinking operations. But there's a lower bound to prevent loading ts too frequently. + newInterval := max(requiredStaleness-adaptiveUpdateTSIntervalShrinkingPreserve, minAllowedAdaptiveUpdateTSInterval) + return adaptiveUpdateTSIntervalStateAdapting, newInterval + } + + if currentAdaptiveUpdateInterval != configuredInterval && now.Sub(lastReachDropThresholdTime) < adaptiveUpdateTSIntervalDelayBeforeRecovering { + // There is a recent request that requires a short staleness. Keep the current adaptive interval. + // If it's not adapting state, it's possible that it's previously in recovering state, and it stops recovering + // as there is a new read operation requesting a short staleness. + return adaptiveUpdateTSIntervalStateAdapting, currentAdaptiveUpdateInterval + } + + return none, 0 + } + + checkRecovering := func() (adaptiveUpdateTSIntervalState, time.Duration) { + if currentAdaptiveUpdateInterval == configuredInterval || now.Sub(lastReachDropThresholdTime) < adaptiveUpdateTSIntervalDelayBeforeRecovering { + return none, 0 + } + + timeSinceLastTick := now.Sub(o.adaptiveUpdateIntervalState.lastTick) + newInterval := currentAdaptiveUpdateInterval + time.Duration(timeSinceLastTick.Seconds()*float64(adaptiveUpdateTSIntervalRecoverPerSecond)) + if newInterval > configuredInterval { + newInterval = configuredInterval + } + + return adaptiveUpdateTSIntervalStateRecovering, newInterval + } + + // Check the specified states in order, until the state becomes determined. + // If it's still undetermined after all checks, keep the previous state. + nextState := func(checkFuncs ...func() (adaptiveUpdateTSIntervalState, time.Duration)) time.Duration { + for _, f := range checkFuncs { + state, newInterval := f() + if state == none { + continue + } + + currentAdaptiveUpdateInterval = newInterval + + // If the final state is the recovering state, do an additional step to check whether it can go back to + // normal state immediately. + if state == adaptiveUpdateTSIntervalStateRecovering { + var nextState adaptiveUpdateTSIntervalState + nextState, newInterval = checkNormal() + if nextState != none { + state = nextState + currentAdaptiveUpdateInterval = newInterval + } + } + + o.adaptiveLastTSUpdateInterval.Store(int64(currentAdaptiveUpdateInterval)) + if o.adaptiveUpdateIntervalState.state != state { + logutil.BgLogger().Info("adaptive update ts interval state transition", + zap.Duration("configuredInterval", configuredInterval), + zap.Duration("prevAdaptiveUpdateInterval", prevAdaptiveUpdateInterval), + zap.Duration("newAdaptiveUpdateInterval", currentAdaptiveUpdateInterval), + zap.Duration("requiredStaleness", requiredStaleness), + zap.Stringer("prevState", o.adaptiveUpdateIntervalState.state), + zap.Stringer("newState", state)) + o.adaptiveUpdateIntervalState.state = state + } + + return currentAdaptiveUpdateInterval + } + return currentAdaptiveUpdateInterval + } + + var newInterval time.Duration + if requiredStaleness != 0 { + newInterval = nextState(checkUnadjustable, checkAdapting) + } else { + newInterval = nextState(checkUnadjustable, checkAdapting, checkNormal, checkRecovering) + } + + metrics.TiKVLowResolutionTSOUpdateIntervalSecondsGauge.Set(newInterval.Seconds()) + + return newInterval +} + +func (o *pdOracle) updateTS(ctx context.Context) { + currentInterval := time.Duration(o.lastTSUpdateInterval.Load()) + ticker := time.NewTicker(currentInterval) defer ticker.Stop() + + doUpdate := func(now time.Time) { + // Update the timestamp for each txnScope + o.lastTSMap.Range(func(key, _ interface{}) bool { + txnScope := key.(string) + ts, err := o.getTimestamp(ctx, txnScope) + if err != nil { + logutil.Logger(ctx).Error("updateTS error", zap.String("txnScope", txnScope), zap.Error(err)) + return true + } + o.setLastTS(ts, txnScope) + return true + }) + + o.adaptiveUpdateIntervalState.lastTick = now + } + for { select { - case <-ticker.C: - // Update the timestamp for each txnScope - o.lastTSMap.Range(func(key, _ interface{}) bool { - txnScope := key.(string) - ts, err := o.getTimestamp(ctx, txnScope) - if err != nil { - logutil.Logger(ctx).Error("updateTS error", zap.String("txnScope", txnScope), zap.Error(err)) - return true + case now := <-ticker.C: + doUpdate(now) + + newInterval := o.nextUpdateInterval(now, 0) + if newInterval != currentInterval { + currentInterval = newInterval + ticker.Reset(currentInterval) + } + + case requiredStaleness := <-o.adaptiveUpdateIntervalState.shrinkIntervalCh: + now := time.Now() + newInterval := o.nextUpdateInterval(now, requiredStaleness) + if newInterval != currentInterval { + currentInterval = newInterval + + if time.Since(o.adaptiveUpdateIntervalState.lastTick) >= currentInterval { + doUpdate(time.Now()) } - o.setLastTS(ts, txnScope) - return true - }) + + ticker.Reset(currentInterval) + } case <-o.quit: return } @@ -347,3 +615,84 @@ func (o *pdOracle) SetExternalTimestamp(ctx context.Context, ts uint64) error { func (o *pdOracle) GetExternalTimestamp(ctx context.Context) (uint64, error) { return o.c.GetExternalTimestamp(ctx) } + +func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Option) (uint64, error) { + ch := o.tsForValidation.DoChan(opt.TxnScope, func() (interface{}, error) { + metrics.TiKVValidateReadTSFromPDCount.Inc() + + // If the call that triggers the execution of this function is canceled by the context, other calls that are + // waiting for reusing the same result should not be canceled. So pass context.Background() instead of the + // current ctx. + res, err := o.GetTimestamp(context.Background(), opt) + return res, err + }) + select { + case <-ctx.Done(): + return 0, errors.WithStack(ctx.Err()) + case res := <-ch: + if res.Err != nil { + return 0, errors.WithStack(res.Err) + } + return res.Val.(uint64), nil + } +} + +func (o *pdOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error { + latestTS, err := o.GetLowResolutionTimestamp(ctx, opt) + // If we fail to get latestTS or the readTS exceeds it, get a timestamp from PD to double-check. + // But we don't need to strictly fetch the latest TS. So if there are already concurrent calls to this function + // loading the latest TS, we can just reuse the same result to avoid too many concurrent GetTS calls. + if err != nil || readTS > latestTS { + currentTS, err := o.getCurrentTSForValidation(ctx, opt) + if err != nil { + return errors.Errorf("fail to validate read timestamp: %v", err) + } + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, currentTS, time.Now()) + if readTS > currentTS { + return errors.Errorf("cannot set read timestamp to a future time") + } + } else { + estimatedCurrentTS, err := o.getStaleTimestamp(opt.TxnScope, 0) + if err != nil { + logutil.Logger(ctx).Warn("failed to estimate current ts by getSlateTimestamp for auto-adjusting update low resolution ts interval", + zap.Error(err), zap.Uint64("readTS", readTS), zap.String("txnScope", opt.TxnScope)) + } else { + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, estimatedCurrentTS, time.Now()) + } + } + return nil +} + +func (o *pdOracle) adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS uint64, currentTS uint64, now time.Time) { + requiredStaleness := oracle.GetTimeFromTS(currentTS).Sub(oracle.GetTimeFromTS(readTS)) + + // Do not acquire the mutex, as here we only needs a rough check. + // So it's possible that we get inconsistent values from these two atomic fields, but it won't cause any problem. + currentUpdateInterval := time.Duration(o.adaptiveLastTSUpdateInterval.Load()) + + if requiredStaleness <= currentUpdateInterval+adaptiveUpdateTSIntervalBlockRecoverThreshold { + // Record the most recent time when there's a read operation requesting the staleness close enough to the + // current update interval. + nowMillis := now.UnixMilli() + last := o.adaptiveUpdateIntervalState.lastShortStalenessReadTime.Load() + if last < nowMillis { + // Do not retry if the CAS fails (which may happen when there are other goroutines updating it + // concurrently), as we don't actually need to set it strictly. + o.adaptiveUpdateIntervalState.lastShortStalenessReadTime.CompareAndSwap(last, nowMillis) + } + } + + if requiredStaleness <= currentUpdateInterval && currentUpdateInterval > minAllowedAdaptiveUpdateTSInterval { + // Considering system time / PD time drifts, it's possible that we get a non-positive value from the + // calculation. Make sure it's always positive before passing it to the updateTS goroutine. + // Note that `nextUpdateInterval` method expects the requiredStaleness is always non-zero when triggerred + // by this path. + requiredStaleness = max(requiredStaleness, time.Millisecond) + // Try to non-blocking send a signal to notify it to change the interval immediately. But if the channel is + // busy, it means that there's another concurrent call trying to update it. Just skip it in this case. + select { + case o.adaptiveUpdateIntervalState.shrinkIntervalCh <- requiredStaleness: + default: + } + } +} diff --git a/oracle/oracles/pd_test.go b/oracle/oracles/pd_test.go index 376e3fa5a..23f85c533 100644 --- a/oracle/oracles/pd_test.go +++ b/oracle/oracles/pd_test.go @@ -32,34 +32,35 @@ // See the License for the specific language governing permissions and // limitations under the License. -package oracles_test +package oracles import ( "context" "math" + "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" "github.com/tikv/client-go/v2/oracle" - "github.com/tikv/client-go/v2/oracle/oracles" + pd "github.com/tikv/pd/client" ) func TestPDOracle_UntilExpired(t *testing.T) { lockAfter, lockExp := 10, 15 - o := oracles.NewEmptyPDOracle() + o := NewEmptyPDOracle() start := time.Now() - oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(start)) + SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(start)) lockTs := oracle.GoTimeToTS(start.Add(time.Duration(lockAfter)*time.Millisecond)) + 1 waitTs := o.UntilExpired(lockTs, uint64(lockExp), &oracle.Option{TxnScope: oracle.GlobalTxnScope}) assert.Equal(t, int64(lockAfter+lockExp), waitTs) } func TestPdOracle_GetStaleTimestamp(t *testing.T) { - o := oracles.NewEmptyPDOracle() + o := NewEmptyPDOracle() start := time.Now() - oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(start)) + SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(start)) ts, err := o.GetStaleTimestamp(context.Background(), oracle.GlobalTxnScope, 10) assert.Nil(t, err) assert.WithinDuration(t, start.Add(-10*time.Second), oracle.GetTimeFromTS(ts), 2*time.Second) @@ -73,9 +74,21 @@ func TestPdOracle_GetStaleTimestamp(t *testing.T) { assert.Regexp(t, ".*invalid prevSecond.*", err.Error()) } +// A mock for pd.Client that only returns global transaction scoped +// timestamps at the same physical time with increasing logical time +type MockPdClient struct { + pd.Client + + logicalTimestamp atomic.Int64 +} + +func (c *MockPdClient) GetTS(ctx context.Context) (int64, int64, error) { + return 0, c.logicalTimestamp.Add(1), nil +} + func TestNonFutureStaleTSO(t *testing.T) { - o := oracles.NewEmptyPDOracle() - oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(time.Now())) + o := NewEmptyPDOracle() + SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(time.Now())) for i := 0; i < 100; i++ { time.Sleep(10 * time.Millisecond) now := time.Now() @@ -84,7 +97,7 @@ func TestNonFutureStaleTSO(t *testing.T) { closeCh := make(chan struct{}) go func() { time.Sleep(100 * time.Microsecond) - oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(now)) + SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(now)) close(closeCh) }() CHECK: @@ -104,3 +117,320 @@ func TestNonFutureStaleTSO(t *testing.T) { } } } + +func TestAdaptiveUpdateTSInterval(t *testing.T) { + oracleInterface, err := NewPdOracle(&MockPdClient{}, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + NoUpdateTS: true, + }) + assert.NoError(t, err) + o := oracleInterface.(*pdOracle) + defer o.Close() + + now := time.Now() + + mockTS := func(beforeNow time.Duration) uint64 { + return oracle.ComposeTS(oracle.GetPhysical(now.Add(-beforeNow)), 1) + } + mustNotifyShrinking := func(expectedRequiredStaleness time.Duration) { + // Normally this channel should be checked in pdOracle.updateTS method. Here we are testing the layer below the + // updateTS method, so we just do this assert to ensure the message is sent to this channel. + select { + case requiredStaleness := <-o.adaptiveUpdateIntervalState.shrinkIntervalCh: + assert.Equal(t, expectedRequiredStaleness, requiredStaleness) + default: + assert.Fail(t, "expects notifying shrinking update interval immediately, but no message received") + } + } + mustNoNotify := func() { + select { + case <-o.adaptiveUpdateIntervalState.shrinkIntervalCh: + assert.Fail(t, "expects not notifying shrinking update interval immediately, but message was received") + default: + } + } + + now = now.Add(time.Second * 2) + assert.Equal(t, time.Second*2, o.nextUpdateInterval(now, 0)) + now = now.Add(time.Second * 2) + assert.Equal(t, time.Second*2, o.nextUpdateInterval(now, 0)) + assert.Equal(t, adaptiveUpdateTSIntervalStateNormal, o.adaptiveUpdateIntervalState.state) + + now = now.Add(time.Second) + // Simulate a read requesting a staleness larger than 2s, in which case nothing special will happen. + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(mockTS(time.Second*3), mockTS(0), now) + mustNoNotify() + assert.Equal(t, time.Second*2, o.nextUpdateInterval(now, 0)) + + now = now.Add(time.Second) + // Simulate a read requesting a staleness less than 2s, in which case it should trigger immediate shrinking on the + // update interval. + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(mockTS(time.Second), mockTS(0), now) + mustNotifyShrinking(time.Second) + expectedInterval := time.Second - adaptiveUpdateTSIntervalShrinkingPreserve + assert.Equal(t, expectedInterval, o.nextUpdateInterval(now, time.Second)) + assert.Equal(t, adaptiveUpdateTSIntervalStateAdapting, o.adaptiveUpdateIntervalState.state) + assert.Equal(t, now.UnixMilli(), o.adaptiveUpdateIntervalState.lastShortStalenessReadTime.Load()) + + // Let read with short staleness continue happening. + now = now.Add(adaptiveUpdateTSIntervalDelayBeforeRecovering / 2) + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(mockTS(time.Second), mockTS(0), now) + mustNoNotify() + assert.Equal(t, now.UnixMilli(), o.adaptiveUpdateIntervalState.lastShortStalenessReadTime.Load()) + + // The adaptiveUpdateTSIntervalDelayBeforeRecovering has not been elapsed since the last time there is a read with short + // staleness. The update interval won't start being reset at this time. + now = now.Add(adaptiveUpdateTSIntervalDelayBeforeRecovering/2 + time.Second) + o.adaptiveUpdateIntervalState.lastTick = now.Add(-time.Second) + assert.Equal(t, expectedInterval, o.nextUpdateInterval(now, 0)) + assert.Equal(t, adaptiveUpdateTSIntervalStateAdapting, o.adaptiveUpdateIntervalState.state) + + // The adaptiveUpdateTSIntervalDelayBeforeRecovering has been elapsed. + now = now.Add(adaptiveUpdateTSIntervalDelayBeforeRecovering / 2) + o.adaptiveUpdateIntervalState.lastTick = now.Add(-time.Second) + expectedInterval += adaptiveUpdateTSIntervalRecoverPerSecond + assert.InEpsilon(t, expectedInterval.Seconds(), o.nextUpdateInterval(now, 0).Seconds(), 1e-3) + assert.Equal(t, adaptiveUpdateTSIntervalStateRecovering, o.adaptiveUpdateIntervalState.state) + o.adaptiveUpdateIntervalState.lastTick = now + now = now.Add(time.Second * 2) + // No effect if the required staleness didn't trigger the threshold. + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(mockTS(expectedInterval+adaptiveUpdateTSIntervalBlockRecoverThreshold*2), mockTS(0), now) + mustNoNotify() + expectedInterval += adaptiveUpdateTSIntervalRecoverPerSecond * 2 + assert.InEpsilon(t, expectedInterval.Seconds(), o.nextUpdateInterval(now, 0).Seconds(), 1e-3) + assert.Equal(t, adaptiveUpdateTSIntervalStateRecovering, o.adaptiveUpdateIntervalState.state) + + // If there's a read operation requires a staleness that is close enough to the current adaptive update interval, + // then block the update interval from recovering. + o.adaptiveUpdateIntervalState.lastTick = now + now = now.Add(time.Second) + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(mockTS(expectedInterval+adaptiveUpdateTSIntervalBlockRecoverThreshold/2), mockTS(0), now) + mustNoNotify() + assert.InEpsilon(t, expectedInterval.Seconds(), o.nextUpdateInterval(now, 0).Seconds(), 1e-3) + assert.Equal(t, adaptiveUpdateTSIntervalStateAdapting, o.adaptiveUpdateIntervalState.state) + o.adaptiveUpdateIntervalState.lastTick = now + now = now.Add(time.Second) + assert.InEpsilon(t, expectedInterval.Seconds(), o.nextUpdateInterval(now, 0).Seconds(), 1e-3) + assert.Equal(t, adaptiveUpdateTSIntervalStateAdapting, o.adaptiveUpdateIntervalState.state) + + // Now adaptiveUpdateTSIntervalDelayBeforeRecovering + 1s has been elapsed. Continue recovering. + now = now.Add(adaptiveUpdateTSIntervalDelayBeforeRecovering) + o.adaptiveUpdateIntervalState.lastTick = now.Add(-time.Second) + expectedInterval += adaptiveUpdateTSIntervalRecoverPerSecond + assert.InEpsilon(t, expectedInterval.Seconds(), o.nextUpdateInterval(now, 0).Seconds(), 1e-3) + assert.Equal(t, adaptiveUpdateTSIntervalStateRecovering, o.adaptiveUpdateIntervalState.state) + + // Without any other interruption, the update interval will gradually recover to the same value as configured. + for { + o.adaptiveUpdateIntervalState.lastTick = now + now = now.Add(time.Second) + expectedInterval += adaptiveUpdateTSIntervalRecoverPerSecond + if expectedInterval >= time.Second*2 { + break + } + assert.InEpsilon(t, expectedInterval.Seconds(), o.nextUpdateInterval(now, 0).Seconds(), 1e-3) + assert.Equal(t, adaptiveUpdateTSIntervalStateRecovering, o.adaptiveUpdateIntervalState.state) + } + expectedInterval = time.Second * 2 + assert.Equal(t, expectedInterval, o.nextUpdateInterval(now, 0)) + assert.Equal(t, adaptiveUpdateTSIntervalStateNormal, o.adaptiveUpdateIntervalState.state) + + // Test adjusting configurations manually. + // When the adaptive update interval is not taking effect, the actual used update interval follows the change of + // the configuration immediately. + err = o.SetLowResolutionTimestampUpdateInterval(time.Second * 1) + assert.NoError(t, err) + assert.Equal(t, time.Second, time.Duration(o.adaptiveLastTSUpdateInterval.Load())) + assert.Equal(t, time.Second, o.nextUpdateInterval(now, 0)) + + err = o.SetLowResolutionTimestampUpdateInterval(time.Second * 2) + assert.NoError(t, err) + assert.Equal(t, time.Second*2, time.Duration(o.adaptiveLastTSUpdateInterval.Load())) + assert.Equal(t, time.Second*2, o.nextUpdateInterval(now, 0)) + + // If the adaptive update interval is taking effect, the configuration change doesn't immediately affect the actual + // update interval. + now = now.Add(time.Second) + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(mockTS(time.Second), mockTS(0), now) + mustNotifyShrinking(time.Second) + expectedInterval = time.Second - adaptiveUpdateTSIntervalShrinkingPreserve + assert.Equal(t, expectedInterval, o.nextUpdateInterval(now, time.Second)) + assert.Equal(t, adaptiveUpdateTSIntervalStateAdapting, o.adaptiveUpdateIntervalState.state) + err = o.SetLowResolutionTimestampUpdateInterval(time.Second * 3) + assert.NoError(t, err) + assert.Equal(t, expectedInterval, time.Duration(o.adaptiveLastTSUpdateInterval.Load())) + assert.Equal(t, expectedInterval, o.nextUpdateInterval(now, 0)) + err = o.SetLowResolutionTimestampUpdateInterval(time.Second) + assert.NoError(t, err) + assert.Equal(t, expectedInterval, time.Duration(o.adaptiveLastTSUpdateInterval.Load())) + assert.Equal(t, expectedInterval, o.nextUpdateInterval(now, 0)) + + // ...unless it's set to a value shorter than the current actual update interval. + err = o.SetLowResolutionTimestampUpdateInterval(time.Millisecond * 800) + assert.NoError(t, err) + assert.Equal(t, time.Millisecond*800, time.Duration(o.adaptiveLastTSUpdateInterval.Load())) + assert.Equal(t, time.Millisecond*800, o.nextUpdateInterval(now, 0)) + assert.Equal(t, adaptiveUpdateTSIntervalStateNormal, o.adaptiveUpdateIntervalState.state) + + // If the configured value is too short, the actual update interval won't be adaptive + err = o.SetLowResolutionTimestampUpdateInterval(minAllowedAdaptiveUpdateTSInterval / 2) + assert.NoError(t, err) + assert.Equal(t, minAllowedAdaptiveUpdateTSInterval/2, time.Duration(o.adaptiveLastTSUpdateInterval.Load())) + assert.Equal(t, minAllowedAdaptiveUpdateTSInterval/2, o.nextUpdateInterval(now, 0)) + assert.Equal(t, adaptiveUpdateTSIntervalStateUnadjustable, o.adaptiveUpdateIntervalState.state) +} + +func TestValidateSnapshotReadTS(t *testing.T) { + pdClient := MockPdClient{} + o, err := NewPdOracle(&pdClient, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + }) + assert.NoError(t, err) + defer o.Close() + + ctx := context.Background() + opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope} + ts, err := o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + assert.GreaterOrEqual(t, ts, uint64(1)) + + err = o.ValidateSnapshotReadTS(ctx, 1, opt) + assert.NoError(t, err) + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + // The readTS exceeds the latest ts, so it first fails the check with the low resolution ts. Then it fallbacks to + // the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass. + err = o.ValidateSnapshotReadTS(ctx, ts+1, opt) + assert.NoError(t, err) + // It can't pass if the readTS is newer than previous ts + 2. + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + err = o.ValidateSnapshotReadTS(ctx, ts+2, opt) + assert.Error(t, err) + + // Simulate other PD clients requests a timestamp. + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + pdClient.logicalTimestamp.Add(2) + err = o.ValidateSnapshotReadTS(ctx, ts+3, opt) + assert.NoError(t, err) +} + +type MockPDClientWithPause struct { + MockPdClient + mu sync.Mutex +} + +func (c *MockPDClientWithPause) GetTS(ctx context.Context) (int64, int64, error) { + c.mu.Lock() + defer c.mu.Unlock() + return c.MockPdClient.GetTS(ctx) +} + +func (c *MockPDClientWithPause) Pause() { + c.mu.Lock() +} + +func (c *MockPDClientWithPause) Resume() { + c.mu.Unlock() +} + +func TestValidateSnapshotReadTSReusingGetTSResult(t *testing.T) { + pdClient := &MockPDClientWithPause{} + o, err := NewPdOracle(pdClient, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + NoUpdateTS: true, + }) + assert.NoError(t, err) + defer o.Close() + + asyncValidate := func(ctx context.Context, readTS uint64) chan error { + ch := make(chan error, 1) + go func() { + err := o.ValidateSnapshotReadTS(ctx, readTS, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + ch <- err + }() + return ch + } + + noResult := func(ch chan error) { + select { + case <-ch: + assert.FailNow(t, "a ValidateSnapshotReadTS operation is not blocked while it's expected to be blocked") + default: + } + } + + cancelIndices := []int{-1, -1, 0, 1} + for i, ts := range []uint64{100, 200, 300, 400} { + // Note: the ts is the result that the next GetTS will return. Any validation with readTS <= ts should pass, otherwise fail. + + // We will cancel the cancelIndex-th validation call. This is for testing that canceling some of the calls + // doesn't affect other calls that are waiting + cancelIndex := cancelIndices[i] + + pdClient.Pause() + + results := make([]chan error, 0, 5) + + ctx, cancel := context.WithCancel(context.Background()) + + getCtx := func(index int) context.Context { + if cancelIndex == index { + return ctx + } else { + return context.Background() + } + } + + results = append(results, asyncValidate(getCtx(0), ts-2)) + results = append(results, asyncValidate(getCtx(1), ts+2)) + results = append(results, asyncValidate(getCtx(2), ts-1)) + results = append(results, asyncValidate(getCtx(3), ts+1)) + results = append(results, asyncValidate(getCtx(4), ts)) + + expectedSucceeds := []bool{true, false, true, false, true} + + time.Sleep(time.Millisecond * 50) + for _, ch := range results { + noResult(ch) + } + + cancel() + + for i, ch := range results { + if i == cancelIndex { + select { + case err := <-ch: + assert.Errorf(t, err, "index: %v", i) + assert.Containsf(t, err.Error(), "context canceled", "index: %v", i) + case <-time.After(time.Second): + assert.FailNowf(t, "expected result to be ready but still blocked", "index: %v", i) + } + } else { + noResult(ch) + } + } + + // ts will be the next ts returned to these validation calls. + pdClient.logicalTimestamp.Store(int64(ts - 1)) + pdClient.Resume() + for i, ch := range results { + if i == cancelIndex { + continue + } + + select { + case err = <-ch: + case <-time.After(time.Second): + assert.FailNowf(t, "expected result to be ready but still blocked", "index: %v", i) + } + if expectedSucceeds[i] { + assert.NoErrorf(t, err, "index: %v", i) + } else { + assert.Errorf(t, err, "index: %v", i) + assert.NotContainsf(t, err.Error(), "context canceled", "index: %v", i) + } + } + } +} diff --git a/tikv/kv.go b/tikv/kv.go index 811a1896c..019a6c2ae 100644 --- a/tikv/kv.go +++ b/tikv/kv.go @@ -184,7 +184,9 @@ func loadOption(store *KVStore, opt ...Option) { // NewKVStore creates a new TiKV store instance. func NewKVStore(uuid string, pdClient pd.Client, spkv SafePointKV, tikvclient Client, opt ...Option) (*KVStore, error) { - o, err := oracles.NewPdOracle(pdClient, time.Duration(oracleUpdateInterval)*time.Millisecond) + o, err := oracles.NewPdOracle(pdClient, &oracles.PDOracleOptions{ + UpdateInterval: time.Duration(oracleUpdateInterval) * time.Millisecond, + }) if err != nil { return nil, err }