From 98312a9d3d574691112f94a719c2ff267248ecc7 Mon Sep 17 00:00:00 2001 From: artem_danilov Date: Mon, 25 Nov 2024 21:27:21 -0800 Subject: [PATCH 1/8] introduce circuit breaker for region calls Signed-off-by: artem_danilov --- client/circuit_breaker/circuit_breaker.go | 288 ++++++++++++++++++ .../circuit_breaker/circuit_breaker_test.go | 204 +++++++++++++ client/client.go | 22 +- client/http/client.go | 40 ++- client/http/interface.go | 7 + client/http/request_info.go | 8 + client/inner_client.go | 22 +- client/metrics/metrics.go | 12 + client/opt/option.go | 18 +- 9 files changed, 608 insertions(+), 13 deletions(-) create mode 100644 client/circuit_breaker/circuit_breaker.go create mode 100644 client/circuit_breaker/circuit_breaker_test.go diff --git a/client/circuit_breaker/circuit_breaker.go b/client/circuit_breaker/circuit_breaker.go new file mode 100644 index 00000000000..38a88a7cb36 --- /dev/null +++ b/client/circuit_breaker/circuit_breaker.go @@ -0,0 +1,288 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package circuit_breaker + +import ( + "errors" + "fmt" + "github.com/prometheus/client_golang/prometheus" + m "github.com/tikv/pd/client/metrics" + "go.uber.org/zap" + "strings" + "sync" + "time" + + "github.com/pingcap/log" +) + +// ErrOpenState is returned when the CircuitBreaker is open or half-open with pending requests. 
+var ErrOpenState = errors.New("circuit breaker is open") + +// Overloading is a type describing service return value +type Overloading int + +const ( + // No means the service is not overloaded + No Overloading = iota + // Yes means the service is overloaded + Yes +) + +// Settings describes configuration for Circuit Breaker +type Settings struct { + // Defines the error rate threshold to trip the circuit breaker. + ErrorRateThresholdPct uint32 + // Defines the average qps over the `error_rate_window` that must be met before evaluating the error rate threshold. + MinQPSForOpen uint32 + // Defines how long to track errors before evaluating error_rate_threshold. + ErrorRateWindow time.Duration + // Defines how long to wait after circuit breaker is open before go to half-open state to send a probe request. + CoolDownInterval time.Duration + // Defines how many subsequent requests to test after cooldown period before fully close the circuit. + HalfOpenSuccessCount uint32 +} + +// CircuitBreaker is a state machine to prevent sending requests that are likely to fail. +type CircuitBreaker[T any] struct { + config *Settings + name string + + mutex sync.Mutex + state *State[T] + + successCounter prometheus.Counter + failureCounter prometheus.Counter + fastFailCounter prometheus.Counter +} + +// StateType is a type that represents a state of CircuitBreaker. +type StateType int + +// States of CircuitBreaker. +const ( + StateClosed StateType = iota + StateOpen + StateHalfOpen +) + +// String implements stringer interface. +func (s StateType) String() string { + switch s { + case StateClosed: + return "closed" + case StateOpen: + return "open" + case StateHalfOpen: + return "half-open" + default: + return fmt.Sprintf("unknown state: %d", s) + } +} + +var replacer = strings.NewReplacer(" ", "_", "-", "_") + +// NewCircuitBreaker returns a new CircuitBreaker configured with the given Settings. 
+func NewCircuitBreaker[T any](name string, st Settings) *CircuitBreaker[T] { + cb := new(CircuitBreaker[T]) + cb.name = name + cb.config = &st + cb.state = cb.newState(time.Now(), StateClosed) + + metricName := replacer.Replace(name) + cb.successCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "success") + cb.failureCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "failure") + cb.fastFailCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "fast_fail") + return cb +} + +// ChangeSettings changes the CircuitBreaker settings. +// The changes will be reflected only in the next evaluation window. +func (cb *CircuitBreaker[T]) ChangeSettings(apply func(config *Settings)) { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + apply(cb.config) +} + +// Execute calls the given function if the CircuitBreaker is closed and returns the result of execution. +// Execute returns an error instantly if the CircuitBreaker is open. +// https://github.com/tikv/rfcs/blob/master/text/0115-circuit-breaker.md +func (cb *CircuitBreaker[T]) Execute(call func() (T, error, Overloading)) (T, error) { + result, err := cb.ExecuteAny(func() (interface{}, error, Overloading) { + res, err, open := call() + return res, err, open + }) + if result == nil { + // this branch is required to support primitive types like int, which can't be nil + var defaultValue T + return defaultValue, err + } else { + return result.(T), err + } +} + +// ExecuteAny is similar to Execute, but allows the caller to return any type of result. 
+func (cb *CircuitBreaker[T]) ExecuteAny(call func() (interface{}, error, Overloading)) (interface{}, error) { + state, err := cb.onRequest() + if err != nil { + var defaultValue interface{} + return defaultValue, err + } + + defer func() { + e := recover() + if e != nil { + cb.onResult(state, Yes) + panic(e) + } + }() + + result, err, open := call() + cb.onResult(state, open) + return result, err +} + +func (cb *CircuitBreaker[T]) onRequest() (*State[T], error) { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + state, err := cb.state.onRequest(cb) + cb.state = state + return state, err +} + +func (cb *CircuitBreaker[T]) onResult(state *State[T], open Overloading) { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + if cb.state == state { + state.onResult(open) + } // else the state moved forward so we don't need to update the counts +} + +type State[T any] struct { + stateType StateType + cb *CircuitBreaker[T] + end time.Time + + pendingCount uint32 + successCount uint32 + failureCount uint32 +} + +// newState creates a new State with the given configuration and reset all success/failure counters. 
+func (cb *CircuitBreaker[T]) newState(now time.Time, stateType StateType) *State[T] { + var end time.Time + var pendingCount uint32 + switch stateType { + case StateClosed: + end = now.Add(cb.config.ErrorRateWindow) + case StateOpen: + end = now.Add(cb.config.CoolDownInterval) + case StateHalfOpen: + // we transition to HalfOpen state on the first request after the cooldown period, + //so we start with 1 pending request + pendingCount = 1 + default: + panic("unknown state") + } + return &State[T]{ + cb: cb, + stateType: stateType, + pendingCount: pendingCount, + end: end, + } +} + +// onRequest transitions the state to the next state based on the current state and the previous requests results +// All state transitions happens at the request evaluation time only +// The implementation represents a state machine effectively +func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { + var now = time.Now() + switch s.stateType { + case StateClosed: + if s.end.Before(now) { + // ErrorRateWindow is over, let's evaluate the error rate + total := s.failureCount + s.successCount + observedErrorRatePct := s.failureCount * 100 / total + if s.cb.config.ErrorRateThresholdPct > 0 && total >= uint32(s.cb.config.ErrorRateWindow.Seconds())*s.cb.config.MinQPSForOpen && observedErrorRatePct >= s.cb.config.ErrorRateThresholdPct { + // the error threshold is breached, let's move to open state and start failing all requests + log.Error("Circuit breaker tripped. 
Starting to fail all requests", + zap.String("name", cb.name), + zap.Uint32("observedErrorRatePct", observedErrorRatePct), + zap.String("config", fmt.Sprintf("%+v", cb.config))) + cb.fastFailCounter.Inc() + return cb.newState(now, StateOpen), ErrOpenState + } else { + // the error threshold is not breached or there were not enough requests to evaluate it, + // continue in the closed state and allow all requests + return cb.newState(now, StateClosed), nil + } + } else { + // continue in closed state till ErrorRateWindow is over + return s, nil + } + case StateOpen: + if s.end.Before(now) { + // CoolDownInterval is over, it is time to transition to half-open state + log.Info("Circuit breaker cooldown period is over. Transitioning to half-open state to test the service", + zap.String("name", cb.name), + zap.String("config", fmt.Sprintf("%+v", cb.config))) + return cb.newState(now, StateHalfOpen), nil + } else { + // continue in the open state till CoolDownInterval is over + cb.fastFailCounter.Inc() + return s, ErrOpenState + } + case StateHalfOpen: + // do we need some expire time here in case of one of pending requests is stuck forever? + if s.failureCount > 0 { + // there were some failures during half-open state, let's go back to open state to wait a bit longer + log.Error("Circuit breaker goes from half-open to open again as errors persist and continue to fail all requests", + zap.String("name", cb.name), + zap.String("config", fmt.Sprintf("%+v", cb.config))) + cb.fastFailCounter.Inc() + return cb.newState(now, StateOpen), ErrOpenState + } else if s.successCount == s.cb.config.HalfOpenSuccessCount { + // all probe requests are succeeded, we can move to closed state and allow all requests + log.Info("Circuit breaker is closed. 
Start allowing all requests", + zap.String("name", cb.name), + zap.String("config", fmt.Sprintf("%+v", cb.config))) + return cb.newState(now, StateClosed), nil + } else if s.pendingCount < s.cb.config.HalfOpenSuccessCount { + // allow more probe requests and continue in half-open state + s.pendingCount++ + return s, nil + } else { + // continue in half-open state till all probe requests are done and fail all other requests for now + cb.fastFailCounter.Inc() + return s, ErrOpenState + } + default: + panic("unknown state") + } +} + +func (s *State[T]) onResult(open Overloading) { + switch open { + case No: + s.successCount++ + s.cb.successCounter.Inc() + case Yes: + s.failureCount++ + s.cb.fastFailCounter.Inc() + default: + panic("unknown state") + } +} diff --git a/client/circuit_breaker/circuit_breaker_test.go b/client/circuit_breaker/circuit_breaker_test.go new file mode 100644 index 00000000000..84b60c30f89 --- /dev/null +++ b/client/circuit_breaker/circuit_breaker_test.go @@ -0,0 +1,204 @@ +package circuit_breaker + +import ( + "errors" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +// advance emulate the state machine clock moves forward by the given duration +func (cb *CircuitBreaker[T]) advance(duration time.Duration) { + cb.state.end = cb.state.end.Add(-duration - 1) +} + +var settings = Settings{ + ErrorRateThresholdPct: 50, + MinQPSForOpen: 10, + ErrorRateWindow: 30 * time.Second, + CoolDownInterval: 10 * time.Second, + HalfOpenSuccessCount: 2, +} + +var minCountToOpen = int(settings.MinQPSForOpen * uint32(settings.ErrorRateWindow.Seconds())) + +func TestCircuitBreaker_Execute_Wrapper_Return_Values(t *testing.T) { + re := require.New(t) + cb := NewCircuitBreaker[int]("test_cb", settings) + originalError := errors.New("circuit breaker is open") + + result, err := cb.Execute(func() (int, error, Overloading) { + return 42, originalError, No + }) + re.Equal(err, originalError) + re.Equal(42, result) + + // same by interpret the result as 
overloading error + result, err = cb.Execute(func() (int, error, Overloading) { + return 42, originalError, Yes + }) + re.Equal(err, originalError) + re.Equal(42, result) +} + +func TestCircuitBreaker_OpenState(t *testing.T) { + re := require.New(t) + cb := NewCircuitBreaker[int]("test_cb", settings) + driveQPS(cb, minCountToOpen, Yes, re) + re.Equal(StateClosed, cb.state.stateType) + assertSucceeds(cb, re) // no error till ErrorRateWindow is finished + cb.advance(settings.ErrorRateWindow) + assertFastFail(cb, re) + re.Equal(StateOpen, cb.state.stateType) +} + +func TestCircuitBreaker_OpenState_Not_Enough_QPS(t *testing.T) { + re := require.New(t) + cb := NewCircuitBreaker[int]("test_cb", settings) + re.Equal(StateClosed, cb.state) + driveQPS(cb, minCountToOpen/2, Yes, re) + cb.advance(settings.ErrorRateWindow) + assertSucceeds(cb, re) + re.Equal(StateClosed, cb.state.stateType) +} + +func TestCircuitBreaker_OpenState_Not_Enough_Error_Rate(t *testing.T) { + re := require.New(t) + cb := NewCircuitBreaker[int]("test_cb", settings) + re.Equal(StateClosed, cb.state.stateType) + driveQPS(cb, minCountToOpen/4, Yes, re) + driveQPS(cb, minCountToOpen, No, re) + cb.advance(settings.ErrorRateWindow) + assertSucceeds(cb, re) + re.Equal(StateClosed, cb.state.stateType) +} + +func TestCircuitBreaker_Half_Open_To_Closed(t *testing.T) { + re := require.New(t) + cb := NewCircuitBreaker[int]("test_cb", settings) + re.Equal(StateClosed, cb.state.stateType) + driveQPS(cb, minCountToOpen, Yes, re) + cb.advance(settings.ErrorRateWindow) + assertFastFail(cb, re) + cb.advance(settings.CoolDownInterval) + assertSucceeds(cb, re) + assertSucceeds(cb, re) + re.Equal(StateHalfOpen, cb.state.stateType) + // state always transferred on the incoming request + assertSucceeds(cb, re) + re.Equal(StateClosed, cb.state.stateType) +} + +func TestCircuitBreaker_Half_Open_To_Open(t *testing.T) { + re := require.New(t) + cb := NewCircuitBreaker[int]("test_cb", settings) + re.Equal(StateClosed, 
cb.state.stateType) + driveQPS(cb, minCountToOpen, Yes, re) + cb.advance(settings.ErrorRateWindow) + assertFastFail(cb, re) + cb.advance(settings.CoolDownInterval) + assertSucceeds(cb, re) + re.Equal(StateHalfOpen, cb.state.stateType) + _, err := cb.Execute(func() (int, error, Overloading) { + return 42, nil, Yes // this trip circuit breaker again + }) + re.NoError(err) + re.Equal(StateHalfOpen, cb.state.stateType) + // state always transferred on the incoming request + assertFastFail(cb, re) + re.Equal(StateOpen, cb.state.stateType) +} + +func TestCircuitBreaker_Half_Open_Fail_Over_Pending_Count(t *testing.T) { + re := require.New(t) + cb := NewCircuitBreaker[int]("test_cb", settings) + re.Equal(StateClosed, cb.state.stateType) + driveQPS(cb, minCountToOpen, Yes, re) + cb.advance(settings.ErrorRateWindow) + assertFastFail(cb, re) + re.Equal(StateOpen, cb.state.stateType) + cb.advance(settings.CoolDownInterval) + + var started []chan bool + var waited []chan bool + var ended []chan bool + for i := 0; i < int(settings.HalfOpenSuccessCount); i++ { + start := make(chan bool) + wait := make(chan bool) + end := make(chan bool) + started = append(started, start) + waited = append(waited, wait) + ended = append(ended, end) + go func() { + defer func() { + end <- true + }() + _, err := cb.Execute(func() (int, error, Overloading) { + start <- true + <-wait + return 42, nil, No + }) + re.NoError(err) + }() + } + for i := 0; i < len(started); i++ { + <-started[i] + } + assertFastFail(cb, re) + re.Equal(StateHalfOpen, cb.state.stateType) + for i := 0; i < len(ended); i++ { + waited[i] <- true + <-ended[i] + } + assertSucceeds(cb, re) + re.Equal(StateClosed, cb.state.stateType) +} + +func TestCircuitBreaker_ChangeSettings(t *testing.T) { + re := require.New(t) + disabledSettings := settings + disabledSettings.ErrorRateThresholdPct = 0 + + cb := NewCircuitBreaker[int]("test_cb", disabledSettings) + driveQPS(cb, minCountToOpen, Yes, re) + cb.advance(settings.ErrorRateWindow) + 
assertSucceeds(cb, re) + re.Equal(StateClosed, cb.state.stateType) + + cb.ChangeSettings(func(config *Settings) { + config.ErrorRateThresholdPct = settings.ErrorRateThresholdPct + }) + re.Equal(settings.ErrorRateThresholdPct, cb.config.ErrorRateThresholdPct) + + driveQPS(cb, minCountToOpen, Yes, re) + cb.advance(settings.ErrorRateWindow) + assertFastFail(cb, re) + re.Equal(StateOpen, cb.state.stateType) +} + +func driveQPS(cb *CircuitBreaker[int], count int, overload Overloading, re *require.Assertions) { + for i := 0; i < count; i++ { + _, err := cb.Execute(func() (int, error, Overloading) { + return 42, nil, overload + }) + re.NoError(err) + } +} + +func assertFastFail(cb *CircuitBreaker[int], re *require.Assertions) { + var executed = false + _, err := cb.Execute(func() (int, error, Overloading) { + executed = true + return 42, nil, No + }) + re.Equal(err, ErrOpenState) + re.False(executed) +} + +func assertSucceeds(cb *CircuitBreaker[int], re *require.Assertions) { + result, err := cb.Execute(func() (int, error, Overloading) { + return 42, nil, No + }) + re.NoError(err) + re.Equal(result, 42) +} diff --git a/client/client.go b/client/client.go index 6781182a44b..74bb72cdf2e 100644 --- a/client/client.go +++ b/client/client.go @@ -18,6 +18,7 @@ import ( "context" "encoding/hex" "fmt" + "github.com/tikv/pd/client/circuit_breaker" "net/url" "runtime/trace" "strings" @@ -521,6 +522,12 @@ func (c *client) UpdateOption(option opt.DynamicOption, value any) error { return errors.New("[pd] invalid value type for TSOClientRPCConcurrency option, it should be int") } c.inner.option.SetTSOClientRPCConcurrency(value) + case opt.RegionMetadataCircuitBreakerSettings: + applySettingsChange, ok := value.(func(config *circuit_breaker.Settings)) + if !ok { + return errors.New("[pd] invalid value type for RegionMetadataCircuitBreakerSettings option, it should be pd.Settings") + } + c.inner.regionMetaCircuitBreaker.ChangeSettings(applySettingsChange) default: return errors.New("[pd] 
unsupported client option") } @@ -715,7 +722,10 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio if serviceClient == nil { return nil, errs.ErrClientGetProtoClient } - resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegion(cctx, req) + resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, error, circuit_breaker.Overloading) { + region, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegion(cctx, req) + return region, err, isOverloaded(err) + }) if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) { protoClient, cctx := c.getClientAndContext(ctx) if protoClient == nil { @@ -755,7 +765,10 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR if serviceClient == nil { return nil, errs.ErrClientGetProtoClient } - resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetPrevRegion(cctx, req) + resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, error, circuit_breaker.Overloading) { + resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetPrevRegion(cctx, req) + return resp, err, isOverloaded(err) + }) if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) { protoClient, cctx := c.getClientAndContext(ctx) if protoClient == nil { @@ -795,7 +808,10 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt if serviceClient == nil { return nil, errs.ErrClientGetProtoClient } - resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegionByID(cctx, req) + resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, error, circuit_breaker.Overloading) { + resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegionByID(cctx, req) + return resp, err, isOverloaded(err) + }) if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) { protoClient, cctx := c.getClientAndContext(ctx) if protoClient == nil { diff --git 
a/client/http/client.go b/client/http/client.go index 9c522d87286..5195005d2d9 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -19,6 +19,7 @@ import ( "context" "crypto/tls" "encoding/json" + cb "github.com/tikv/pd/client/circuit_breaker" "io" "net/http" "time" @@ -60,9 +61,10 @@ type clientInner struct { // source is used to mark the source of the client creation, // it will also be used in the caller ID of the inner client. - source string - tlsConf *tls.Config - cli *http.Client + source string + tlsConf *tls.Config + cli *http.Client + regionMetaCircuitBreaker *cb.CircuitBreaker[*http.Response] requestCounter *prometheus.CounterVec executionDuration *prometheus.HistogramVec @@ -186,6 +188,19 @@ func noNeedRetry(statusCode int) bool { statusCode == http.StatusBadRequest } +func isOverloaded(resp *http.Response) cb.Overloading { + if resp == nil { + // request probably didn't reach target hence don't count it + return cb.Yes + } + switch resp.StatusCode { + case http.StatusRequestTimeout, http.StatusTooManyRequests, http.StatusServiceUnavailable: + return cb.Yes + default: + return cb.No + } +} + func (ci *clientInner) doRequest( ctx context.Context, serverURL string, reqInfo *requestInfo, @@ -213,9 +228,17 @@ func (ci *clientInner) doRequest( opt(req.Header) } req.Header.Set(xCallerIDKey, callerID) - start := time.Now() - resp, err := ci.cli.Do(req) + var resp *http.Response + if _, exists := regionRequestNames[reqInfo.name]; exists { + resp, err = ci.regionMetaCircuitBreaker.Execute(func() (*http.Response, error, cb.Overloading) { + resp, err := ci.cli.Do(req) + return resp, err, isOverloaded(resp) + }) + } else { + resp, err = ci.cli.Do(req) + } + if err != nil { ci.reqCounter(name, networkErrorStatus) log.Error("[pd] do http request failed", append(logFields, zap.Error(err))...) 
@@ -302,6 +325,13 @@ func WithMetrics( } } +// WithRegionMetaCircuitBreaker configures the client with circuit breaker for region meta calls +func WithRegionMetaCircuitBreaker(config cb.Settings) ClientOption { + return func(c *client) { + c.inner.regionMetaCircuitBreaker = cb.NewCircuitBreaker[*http.Response]("region-meta-http", config) + } +} + // NewClientWithServiceDiscovery creates a PD HTTP client with the given PD service discovery. func NewClientWithServiceDiscovery( source string, diff --git a/client/http/interface.go b/client/http/interface.go index f5cd1a38211..13e04776f72 100644 --- a/client/http/interface.go +++ b/client/http/interface.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "fmt" + cb "github.com/tikv/pd/client/circuit_breaker" "net/http" "strconv" "strings" @@ -63,6 +64,7 @@ type Client interface { GetClusterStatus(context.Context) (*ClusterState, error) GetStatus(context.Context) (*State, error) GetReplicateConfig(context.Context) (map[string]any, error) + UpdateCircuitBreakerSettings(apply func(config *pd.Settings)) /* Scheduler-related interfaces */ GetSchedulers(context.Context) ([]string, error) CreateScheduler(ctx context.Context, name string, storeID uint64) error @@ -541,6 +543,11 @@ func (c *client) GetReplicateConfig(ctx context.Context) (map[string]any, error) return config, nil } +// UpdateCircuitBreakerSettings updates config for circuit breaker +func (c *client) UpdateCircuitBreakerSettings(apply func(config *cb.Settings)) { + c.inner.regionMetaCircuitBreaker.ChangeSettings(apply) +} + // GetAllPlacementRuleBundles gets all placement rules bundles. 
func (c *client) GetAllPlacementRuleBundles(ctx context.Context) ([]*GroupBundle, error) { var bundles []*GroupBundle diff --git a/client/http/request_info.go b/client/http/request_info.go index 94f71c6186e..913ec24ef71 100644 --- a/client/http/request_info.go +++ b/client/http/request_info.go @@ -89,6 +89,14 @@ const ( DeleteGCSafePointName = "DeleteGCSafePoint" ) +var regionRequestNames = map[string]struct{}{ + getRegionByIDName: {}, + getRegionByKeyName: {}, + getRegionsName: {}, + getRegionsByKeyRangeName: {}, + getRegionsByStoreIDName: {}, +} + type requestInfo struct { callerID string name string diff --git a/client/inner_client.go b/client/inner_client.go index 467d6b66352..2d927410407 100644 --- a/client/inner_client.go +++ b/client/inner_client.go @@ -3,18 +3,21 @@ package pd import ( "context" "crypto/tls" + "google.golang.org/grpc/codes" "sync" "time" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" + cb "github.com/tikv/pd/client/circuit_breaker" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" "github.com/tikv/pd/client/opt" sd "github.com/tikv/pd/client/servicediscovery" "go.uber.org/zap" "google.golang.org/grpc" + "google.golang.org/grpc/status" ) const ( @@ -23,10 +26,11 @@ const ( ) type innerClient struct { - keyspaceID uint32 - svrUrls []string - pdSvcDiscovery sd.ServiceDiscovery - tokenDispatcher *tokenDispatcher + keyspaceID uint32 + svrUrls []string + pdSvcDiscovery sd.ServiceDiscovery + tokenDispatcher *tokenDispatcher + regionMetaCircuitBreaker *cb.CircuitBreaker[*pdpb.GetRegionResponse] // For service mode switching. 
serviceModeKeeper @@ -52,6 +56,7 @@ func (c *innerClient) init(updateKeyspaceIDCb sd.UpdateKeyspaceIDFunc) error { } return err } + c.regionMetaCircuitBreaker = cb.NewCircuitBreaker[*pdpb.GetRegionResponse]("region_meta", c.option.RegionMetaCircuitBreakerSettings) return nil } @@ -244,3 +249,12 @@ func (c *innerClient) dispatchTSORequestWithRetry(ctx context.Context) TSFuture } return req } + +func isOverloaded(err error) cb.Overloading { + switch status.Code(errors.Cause(err)) { + case codes.DeadlineExceeded, codes.Unavailable, codes.ResourceExhausted: + return cb.Yes + default: + return cb.No + } +} diff --git a/client/metrics/metrics.go b/client/metrics/metrics.go index da36217eb34..3a3199c74a6 100644 --- a/client/metrics/metrics.go +++ b/client/metrics/metrics.go @@ -56,6 +56,8 @@ var ( OngoingRequestCountGauge *prometheus.GaugeVec // EstimateTSOLatencyGauge is the gauge to indicate the estimated latency of TSO requests. EstimateTSOLatencyGauge *prometheus.GaugeVec + // CircuitBreakerCounters is a vector for different circuit breaker counters + CircuitBreakerCounters *prometheus.CounterVec ) func initMetrics(constLabels prometheus.Labels) { @@ -144,6 +146,15 @@ func initMetrics(constLabels prometheus.Labels) { Help: "Estimated latency of an RTT of getting TSO", ConstLabels: constLabels, }, []string{"stream"}) + + CircuitBreakerCounters = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "pd_client", + Subsystem: "request", + Name: "circuit_breaker_count", + Help: "Circuit Breaker counters", + ConstLabels: constLabels, + }, []string{"name", "success"}) } // CmdDurationXXX and CmdFailedDurationXXX are the durations of the client commands. 
@@ -259,4 +270,5 @@ func registerMetrics() { prometheus.MustRegister(TSOBatchSendLatency) prometheus.MustRegister(RequestForwarded) prometheus.MustRegister(EstimateTSOLatencyGauge) + prometheus.MustRegister(CircuitBreakerCounters) } diff --git a/client/opt/option.go b/client/opt/option.go index a9f6083484e..8be77e4660f 100644 --- a/client/opt/option.go +++ b/client/opt/option.go @@ -15,6 +15,7 @@ package opt import ( + cb "github.com/tikv/pd/client/circuit_breaker" "sync/atomic" "time" @@ -46,6 +47,8 @@ const ( EnableFollowerHandle // TSOClientRPCConcurrency controls the amount of ongoing TSO RPC requests at the same time in a single TSO client. TSOClientRPCConcurrency + // RegionMetadataCircuitBreakerSettings controls settings for circuit breaker for region metadata requests. + RegionMetadataCircuitBreakerSettings dynamicOptionCount ) @@ -65,7 +68,8 @@ type Option struct { // Dynamic options. dynamicOptions [dynamicOptionCount]atomic.Value - EnableTSOFollowerProxyCh chan struct{} + EnableTSOFollowerProxyCh chan struct{} + RegionMetaCircuitBreakerSettings cb.Settings } // NewOption creates a new PD client option with the default values set. @@ -145,6 +149,11 @@ func (o *Option) GetTSOClientRPCConcurrency() int { return o.dynamicOptions[TSOClientRPCConcurrency].Load().(int) } +// GetRegionMetadataCircuitBreakerSettings gets circuit breaker settings for PD region metadata calls. +func (o *Option) GetRegionMetadataCircuitBreakerSettings() cb.Settings { + return o.dynamicOptions[RegionMetadataCircuitBreakerSettings].Load().(cb.Settings) +} + // ClientOption configures client. 
type ClientOption func(*Option) @@ -199,6 +208,13 @@ func WithInitMetricsOption(initMetrics bool) ClientOption { } } +// WithRegionMetaCircuitBreaker configures the client with circuit breaker for region meta calls +func WithRegionMetaCircuitBreaker(config cb.Settings) ClientOption { + return func(op *Option) { + op.RegionMetaCircuitBreakerSettings = config + } +} + // GetStoreOp represents available options when getting stores. type GetStoreOp struct { ExcludeTombstone bool From 911cb56f4c4177fa44e9b15314a3d758d5b586a0 Mon Sep 17 00:00:00 2001 From: artem_danilov Date: Mon, 25 Nov 2024 22:06:10 -0800 Subject: [PATCH 2/8] add default circuit breaker Signed-off-by: artem_danilov --- client/circuit_breaker/circuit_breaker.go | 8 ++++++++ client/circuit_breaker/circuit_breaker_test.go | 8 +++----- client/http/client.go | 2 +- client/opt/option.go | 9 +++++---- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/client/circuit_breaker/circuit_breaker.go b/client/circuit_breaker/circuit_breaker.go index 38a88a7cb36..c24b894071b 100644 --- a/client/circuit_breaker/circuit_breaker.go +++ b/client/circuit_breaker/circuit_breaker.go @@ -53,6 +53,14 @@ type Settings struct { HalfOpenSuccessCount uint32 } +var AlwaysOpenSettings = Settings{ + ErrorRateThresholdPct: 0, // never trips + ErrorRateWindow: 10 * time.Second, // effectively results in testing for new settings every 10 seconds + MinQPSForOpen: 10, + CoolDownInterval: 10 * time.Second, + HalfOpenSuccessCount: 1, +} + // CircuitBreaker is a state machine to prevent sending requests that are likely to fail. 
type CircuitBreaker[T any] struct { config *Settings diff --git a/client/circuit_breaker/circuit_breaker_test.go b/client/circuit_breaker/circuit_breaker_test.go index 84b60c30f89..1c47b6c1327 100644 --- a/client/circuit_breaker/circuit_breaker_test.go +++ b/client/circuit_breaker/circuit_breaker_test.go @@ -156,12 +156,10 @@ func TestCircuitBreaker_Half_Open_Fail_Over_Pending_Count(t *testing.T) { func TestCircuitBreaker_ChangeSettings(t *testing.T) { re := require.New(t) - disabledSettings := settings - disabledSettings.ErrorRateThresholdPct = 0 - cb := NewCircuitBreaker[int]("test_cb", disabledSettings) - driveQPS(cb, minCountToOpen, Yes, re) - cb.advance(settings.ErrorRateWindow) + cb := NewCircuitBreaker[int]("test_cb", AlwaysOpenSettings) + driveQPS(cb, int(AlwaysOpenSettings.MinQPSForOpen*uint32(AlwaysOpenSettings.ErrorRateWindow.Seconds())), Yes, re) + cb.advance(AlwaysOpenSettings.ErrorRateWindow) assertSucceeds(cb, re) re.Equal(StateClosed, cb.state.stateType) diff --git a/client/http/client.go b/client/http/client.go index 5195005d2d9..31836c1f0c5 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -230,7 +230,7 @@ func (ci *clientInner) doRequest( req.Header.Set(xCallerIDKey, callerID) start := time.Now() var resp *http.Response - if _, exists := regionRequestNames[reqInfo.name]; exists { + if _, exists := regionRequestNames[reqInfo.name]; exists && ci.regionMetaCircuitBreaker != nil { resp, err = ci.regionMetaCircuitBreaker.Execute(func() (*http.Response, error, cb.Overloading) { resp, err := ci.cli.Do(req) return resp, err, isOverloaded(resp) diff --git a/client/opt/option.go b/client/opt/option.go index 8be77e4660f..73f7a3a440d 100644 --- a/client/opt/option.go +++ b/client/opt/option.go @@ -75,10 +75,11 @@ type Option struct { // NewOption creates a new PD client option with the default values set. 
func NewOption() *Option { co := &Option{ - Timeout: defaultPDTimeout, - MaxRetryTimes: maxInitClusterRetries, - EnableTSOFollowerProxyCh: make(chan struct{}, 1), - InitMetrics: true, + Timeout: defaultPDTimeout, + MaxRetryTimes: maxInitClusterRetries, + EnableTSOFollowerProxyCh: make(chan struct{}, 1), + InitMetrics: true, + RegionMetaCircuitBreakerSettings: cb.AlwaysOpenSettings, } co.dynamicOptions[MaxTSOBatchWaitInterval].Store(defaultMaxTSOBatchWaitInterval) From b7678b725fd85f0bd25f13a8981503e305fe32fd Mon Sep 17 00:00:00 2001 From: artem_danilov Date: Tue, 26 Nov 2024 20:00:47 -0800 Subject: [PATCH 3/8] fix merge issue interface.go:67:50: undefined: pd Signed-off-by: artem_danilov --- client/http/interface.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/http/interface.go b/client/http/interface.go index 13e04776f72..1de8f22d345 100644 --- a/client/http/interface.go +++ b/client/http/interface.go @@ -64,7 +64,7 @@ type Client interface { GetClusterStatus(context.Context) (*ClusterState, error) GetStatus(context.Context) (*State, error) GetReplicateConfig(context.Context) (map[string]any, error) - UpdateCircuitBreakerSettings(apply func(config *pd.Settings)) + UpdateCircuitBreakerSettings(apply func(config *cb.Settings)) /* Scheduler-related interfaces */ GetSchedulers(context.Context) ([]string, error) CreateScheduler(ctx context.Context, name string, storeID uint64) error From a95b198e3f3ba8355def44b1210237d67c9e9a19 Mon Sep 17 00:00:00 2001 From: artem_danilov Date: Thu, 28 Nov 2024 14:21:34 -0800 Subject: [PATCH 4/8] address PR comments Signed-off-by: artem_danilov --- .../circuit_breaker.go | 70 +++++++------------ .../circuit_breaker_test.go | 49 +++++++------ client/client.go | 17 ++--- client/errs/errno.go | 1 + client/http/client.go | 40 ++--------- client/http/interface.go | 7 -- client/http/request_info.go | 8 --- client/inner_client.go | 5 +- client/opt/option.go | 3 +- 9 files changed, 73 insertions(+), 127 deletions(-) 
rename client/{circuit_breaker => circuitbreaker}/circuit_breaker.go (85%) rename client/{circuit_breaker => circuitbreaker}/circuit_breaker_test.go (84%) diff --git a/client/circuit_breaker/circuit_breaker.go b/client/circuitbreaker/circuit_breaker.go similarity index 85% rename from client/circuit_breaker/circuit_breaker.go rename to client/circuitbreaker/circuit_breaker.go index c24b894071b..33e59517bcc 100644 --- a/client/circuit_breaker/circuit_breaker.go +++ b/client/circuitbreaker/circuit_breaker.go @@ -11,32 +11,31 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -package circuit_breaker +package circuitbreaker import ( - "errors" "fmt" - "github.com/prometheus/client_golang/prometheus" - m "github.com/tikv/pd/client/metrics" - "go.uber.org/zap" "strings" "sync" "time" + "github.com/tikv/pd/client/errs" + + "github.com/prometheus/client_golang/prometheus" + m "github.com/tikv/pd/client/metrics" + "go.uber.org/zap" + "github.com/pingcap/log" ) -// ErrOpenState is returned when the CircuitBreaker is open or half-open with pending requests. -var ErrOpenState = errors.New("circuit breaker is open") - // Overloading is a type describing service return value -type Overloading int +type Overloading bool const ( // No means the service is not overloaded - No Overloading = iota + No = false // Yes means the service is overloaded - Yes + Yes = true ) // Settings describes configuration for Circuit Breaker @@ -53,6 +52,7 @@ type Settings struct { HalfOpenSuccessCount uint32 } +// AlwaysOpenSettings is a configuration that never trips the circuit breaker. 
var AlwaysOpenSettings = Settings{ ErrorRateThresholdPct: 0, // never trips ErrorRateWindow: 10 * time.Second, // effectively results in testing for new settings every 10 seconds @@ -126,25 +126,10 @@ func (cb *CircuitBreaker[T]) ChangeSettings(apply func(config *Settings)) { // Execute calls the given function if the CircuitBreaker is closed and returns the result of execution. // Execute returns an error instantly if the CircuitBreaker is open. // https://github.com/tikv/rfcs/blob/master/text/0115-circuit-breaker.md -func (cb *CircuitBreaker[T]) Execute(call func() (T, error, Overloading)) (T, error) { - result, err := cb.ExecuteAny(func() (interface{}, error, Overloading) { - res, err, open := call() - return res, err, open - }) - if result == nil { - // this branch is required to support primitive types like int, which can't be nil - var defaultValue T - return defaultValue, err - } else { - return result.(T), err - } -} - -// ExecuteAny is similar to Execute, but allows the caller to return any type of result. -func (cb *CircuitBreaker[T]) ExecuteAny(call func() (interface{}, error, Overloading)) (interface{}, error) { +func (cb *CircuitBreaker[T]) Execute(call func() (T, Overloading, error)) (T, error) { state, err := cb.onRequest() if err != nil { - var defaultValue interface{} + var defaultValue T return defaultValue, err } @@ -156,8 +141,8 @@ func (cb *CircuitBreaker[T]) ExecuteAny(call func() (interface{}, error, Overloa } }() - result, err, open := call() - cb.onResult(state, open) + result, overloaded, err := call() + cb.onResult(state, overloaded) return result, err } @@ -179,6 +164,7 @@ func (cb *CircuitBreaker[T]) onResult(state *State[T], open Overloading) { } // else the state moved forward so we don't need to update the counts } +// State represents the state of CircuitBreaker. 
type State[T any] struct { stateType StateType cb *CircuitBreaker[T] @@ -200,7 +186,7 @@ func (cb *CircuitBreaker[T]) newState(now time.Time, stateType StateType) *State end = now.Add(cb.config.CoolDownInterval) case StateHalfOpen: // we transition to HalfOpen state on the first request after the cooldown period, - //so we start with 1 pending request + // so we start with 1 pending request pendingCount = 1 default: panic("unknown state") @@ -231,16 +217,14 @@ func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { zap.Uint32("observedErrorRatePct", observedErrorRatePct), zap.String("config", fmt.Sprintf("%+v", cb.config))) cb.fastFailCounter.Inc() - return cb.newState(now, StateOpen), ErrOpenState - } else { - // the error threshold is not breached or there were not enough requests to evaluate it, - // continue in the closed state and allow all requests - return cb.newState(now, StateClosed), nil + return cb.newState(now, StateOpen), errs.ErrCircuitBreakerOpen } - } else { - // continue in closed state till ErrorRateWindow is over - return s, nil + // the error threshold is not breached or there were not enough requests to evaluate it, + // continue in the closed state and allow all requests + return cb.newState(now, StateClosed), nil } + // continue in closed state till ErrorRateWindow is over + return s, nil case StateOpen: if s.end.Before(now) { // CoolDownInterval is over, it is time to transition to half-open state @@ -251,7 +235,7 @@ func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { } else { // continue in the open state till CoolDownInterval is over cb.fastFailCounter.Inc() - return s, ErrOpenState + return s, errs.ErrCircuitBreakerOpen } case StateHalfOpen: // do we need some expire time here in case of one of pending requests is stuck forever? 
@@ -261,7 +245,7 @@ func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { zap.String("name", cb.name), zap.String("config", fmt.Sprintf("%+v", cb.config))) cb.fastFailCounter.Inc() - return cb.newState(now, StateOpen), ErrOpenState + return cb.newState(now, StateOpen), errs.ErrCircuitBreakerOpen } else if s.successCount == s.cb.config.HalfOpenSuccessCount { // all probe requests are succeeded, we can move to closed state and allow all requests log.Info("Circuit breaker is closed. Start allowing all requests", @@ -275,7 +259,7 @@ func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { } else { // continue in half-open state till all probe requests are done and fail all other requests for now cb.fastFailCounter.Inc() - return s, ErrOpenState + return s, errs.ErrCircuitBreakerOpen } default: panic("unknown state") @@ -289,7 +273,7 @@ func (s *State[T]) onResult(open Overloading) { s.cb.successCounter.Inc() case Yes: s.failureCount++ - s.cb.fastFailCounter.Inc() + s.cb.failureCounter.Inc() default: panic("unknown state") } diff --git a/client/circuit_breaker/circuit_breaker_test.go b/client/circuitbreaker/circuit_breaker_test.go similarity index 84% rename from client/circuit_breaker/circuit_breaker_test.go rename to client/circuitbreaker/circuit_breaker_test.go index 1c47b6c1327..012021a1bbf 100644 --- a/client/circuit_breaker/circuit_breaker_test.go +++ b/client/circuitbreaker/circuit_breaker_test.go @@ -1,10 +1,13 @@ -package circuit_breaker +package circuitbreaker import ( "errors" - "github.com/stretchr/testify/require" "testing" "time" + + "github.com/tikv/pd/client/errs" + + "github.com/stretchr/testify/require" ) // advance emulate the state machine clock moves forward by the given duration @@ -27,15 +30,15 @@ func TestCircuitBreaker_Execute_Wrapper_Return_Values(t *testing.T) { cb := NewCircuitBreaker[int]("test_cb", settings) originalError := errors.New("circuit breaker is open") - result, err := cb.Execute(func() (int, 
error, Overloading) { - return 42, originalError, No + result, err := cb.Execute(func() (int, Overloading, error) { + return 42, No, originalError }) re.Equal(err, originalError) re.Equal(42, result) // same by interpret the result as overloading error - result, err = cb.Execute(func() (int, error, Overloading) { - return 42, originalError, Yes + result, err = cb.Execute(func() (int, Overloading, error) { + return 42, Yes, originalError }) re.Equal(err, originalError) re.Equal(42, result) @@ -55,7 +58,7 @@ func TestCircuitBreaker_OpenState(t *testing.T) { func TestCircuitBreaker_OpenState_Not_Enough_QPS(t *testing.T) { re := require.New(t) cb := NewCircuitBreaker[int]("test_cb", settings) - re.Equal(StateClosed, cb.state) + re.Equal(StateClosed, cb.state.stateType) driveQPS(cb, minCountToOpen/2, Yes, re) cb.advance(settings.ErrorRateWindow) assertSucceeds(cb, re) @@ -99,8 +102,8 @@ func TestCircuitBreaker_Half_Open_To_Open(t *testing.T) { cb.advance(settings.CoolDownInterval) assertSucceeds(cb, re) re.Equal(StateHalfOpen, cb.state.stateType) - _, err := cb.Execute(func() (int, error, Overloading) { - return 42, nil, Yes // this trip circuit breaker again + _, err := cb.Execute(func() (int, Overloading, error) { + return 42, Yes, nil // this trip circuit breaker again }) re.NoError(err) re.Equal(StateHalfOpen, cb.state.stateType) @@ -122,7 +125,7 @@ func TestCircuitBreaker_Half_Open_Fail_Over_Pending_Count(t *testing.T) { var started []chan bool var waited []chan bool var ended []chan bool - for i := 0; i < int(settings.HalfOpenSuccessCount); i++ { + for range int(settings.HalfOpenSuccessCount) { start := make(chan bool) wait := make(chan bool) end := make(chan bool) @@ -133,20 +136,20 @@ func TestCircuitBreaker_Half_Open_Fail_Over_Pending_Count(t *testing.T) { defer func() { end <- true }() - _, err := cb.Execute(func() (int, error, Overloading) { + _, err := cb.Execute(func() (int, Overloading, error) { start <- true <-wait - return 42, nil, No + return 42, No, 
nil }) re.NoError(err) }() } - for i := 0; i < len(started); i++ { + for i := range started { <-started[i] } assertFastFail(cb, re) re.Equal(StateHalfOpen, cb.state.stateType) - for i := 0; i < len(ended); i++ { + for i := range ended { waited[i] <- true <-ended[i] } @@ -175,9 +178,9 @@ func TestCircuitBreaker_ChangeSettings(t *testing.T) { } func driveQPS(cb *CircuitBreaker[int], count int, overload Overloading, re *require.Assertions) { - for i := 0; i < count; i++ { - _, err := cb.Execute(func() (int, error, Overloading) { - return 42, nil, overload + for range count { + _, err := cb.Execute(func() (int, Overloading, error) { + return 42, overload, nil }) re.NoError(err) } @@ -185,18 +188,18 @@ func driveQPS(cb *CircuitBreaker[int], count int, overload Overloading, re *requ func assertFastFail(cb *CircuitBreaker[int], re *require.Assertions) { var executed = false - _, err := cb.Execute(func() (int, error, Overloading) { + _, err := cb.Execute(func() (int, Overloading, error) { executed = true - return 42, nil, No + return 42, No, nil }) - re.Equal(err, ErrOpenState) + re.Equal(err, errs.ErrCircuitBreakerOpen) re.False(executed) } func assertSucceeds(cb *CircuitBreaker[int], re *require.Assertions) { - result, err := cb.Execute(func() (int, error, Overloading) { - return 42, nil, No + result, err := cb.Execute(func() (int, Overloading, error) { + return 42, No, nil }) re.NoError(err) - re.Equal(result, 42) + re.Equal(42, result) } diff --git a/client/client.go b/client/client.go index adb3d1a5890..93d1cfaf79b 100644 --- a/client/client.go +++ b/client/client.go @@ -18,13 +18,14 @@ import ( "context" "encoding/hex" "fmt" - "github.com/tikv/pd/client/circuit_breaker" "net/url" "runtime/trace" "strings" "sync" "time" + cb "github.com/tikv/pd/client/circuitbreaker" + "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" "github.com/pingcap/failpoint" @@ -523,7 +524,7 @@ func (c *client) UpdateOption(option opt.DynamicOption, value any) error { } 
c.inner.option.SetTSOClientRPCConcurrency(value) case opt.RegionMetadataCircuitBreakerSettings: - applySettingsChange, ok := value.(func(config *circuit_breaker.Settings)) + applySettingsChange, ok := value.(func(config *cb.Settings)) if !ok { return errors.New("[pd] invalid value type for RegionMetadataCircuitBreakerSettings option, it should be pd.Settings") } @@ -722,9 +723,9 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio if serviceClient == nil { return nil, errs.ErrClientGetProtoClient } - resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, error, circuit_breaker.Overloading) { + resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) { region, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegion(cctx, req) - return region, err, isOverloaded(err) + return region, isOverloaded(err), err }) if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) { protoClient, cctx := c.getClientAndContext(ctx) @@ -765,9 +766,9 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR if serviceClient == nil { return nil, errs.ErrClientGetProtoClient } - resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, error, circuit_breaker.Overloading) { + resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) { resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetPrevRegion(cctx, req) - return resp, err, isOverloaded(err) + return resp, isOverloaded(err), err }) if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) { protoClient, cctx := c.getClientAndContext(ctx) @@ -808,9 +809,9 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt if serviceClient == nil { return nil, errs.ErrClientGetProtoClient } - resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, 
error, circuit_breaker.Overloading) { + resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) { resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegionByID(cctx, req) - return resp, err, isOverloaded(err) + return resp, isOverloaded(err), err }) if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) { protoClient, cctx := c.getClientAndContext(ctx) diff --git a/client/errs/errno.go b/client/errs/errno.go index df8b677525a..25665f01017 100644 --- a/client/errs/errno.go +++ b/client/errs/errno.go @@ -70,6 +70,7 @@ var ( ErrClientGetServingEndpoint = errors.Normalize("get serving endpoint failed", errors.RFCCodeText("PD:client:ErrClientGetServingEndpoint")) ErrClientFindGroupByKeyspaceID = errors.Normalize("can't find keyspace group by keyspace id", errors.RFCCodeText("PD:client:ErrClientFindGroupByKeyspaceID")) ErrClientWatchGCSafePointV2Stream = errors.Normalize("watch gc safe point v2 stream failed", errors.RFCCodeText("PD:client:ErrClientWatchGCSafePointV2Stream")) + ErrCircuitBreakerOpen = errors.Normalize("circuit breaker is open", errors.RFCCodeText("PD:client:ErrCircuitBreakerOpen")) ) // grpcutil errors diff --git a/client/http/client.go b/client/http/client.go index 2b8f422a492..c813474fcf6 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -19,7 +19,6 @@ import ( "context" "crypto/tls" "encoding/json" - cb "github.com/tikv/pd/client/circuit_breaker" "io" "net/http" "time" @@ -61,10 +60,9 @@ type clientInner struct { // source is used to mark the source of the client creation, // it will also be used in the caller ID of the inner client. 
- source string - tlsConf *tls.Config - cli *http.Client - regionMetaCircuitBreaker *cb.CircuitBreaker[*http.Response] + source string + tlsConf *tls.Config + cli *http.Client requestCounter *prometheus.CounterVec executionDuration *prometheus.HistogramVec @@ -188,19 +186,6 @@ func noNeedRetry(statusCode int) bool { statusCode == http.StatusBadRequest } -func isOverloaded(resp *http.Response) cb.Overloading { - if resp == nil { - // request probably didn't reach target hence don't count it - return cb.Yes - } - switch resp.StatusCode { - case http.StatusRequestTimeout, http.StatusTooManyRequests, http.StatusServiceUnavailable: - return cb.Yes - default: - return cb.No - } -} - func (ci *clientInner) doRequest( ctx context.Context, serverURL string, reqInfo *requestInfo, @@ -228,17 +213,9 @@ func (ci *clientInner) doRequest( opt(req.Header) } req.Header.Set(xCallerIDKey, callerID) - start := time.Now() - var resp *http.Response - if _, exists := regionRequestNames[reqInfo.name]; exists && ci.regionMetaCircuitBreaker != nil { - resp, err = ci.regionMetaCircuitBreaker.Execute(func() (*http.Response, error, cb.Overloading) { - resp, err := ci.cli.Do(req) - return resp, err, isOverloaded(resp) - }) - } else { - resp, err = ci.cli.Do(req) - } + start := time.Now() + resp, err := ci.cli.Do(req) if err != nil { ci.reqCounter(name, networkErrorStatus) log.Error("[pd] do http request failed", append(logFields, zap.Error(err))...) @@ -325,13 +302,6 @@ func WithMetrics( } } -// WithRegionMetaCircuitBreaker configures the client with circuit breaker for region meta calls -func WithRegionMetaCircuitBreaker(config cb.Settings) ClientOption { - return func(c *client) { - c.inner.regionMetaCircuitBreaker = cb.NewCircuitBreaker[*http.Response]("region-meta-http", config) - } -} - // NewClientWithServiceDiscovery creates a PD HTTP client with the given PD service discovery. 
func NewClientWithServiceDiscovery( source string, diff --git a/client/http/interface.go b/client/http/interface.go index 511fa9b6336..772599e27fb 100644 --- a/client/http/interface.go +++ b/client/http/interface.go @@ -18,7 +18,6 @@ import ( "context" "encoding/json" "fmt" - cb "github.com/tikv/pd/client/circuit_breaker" "net/http" "strconv" "strings" @@ -64,7 +63,6 @@ type Client interface { GetClusterStatus(context.Context) (*ClusterState, error) GetStatus(context.Context) (*State, error) GetReplicateConfig(context.Context) (map[string]any, error) - UpdateCircuitBreakerSettings(apply func(config *cb.Settings)) /* Scheduler-related interfaces */ GetSchedulers(context.Context) ([]string, error) CreateScheduler(ctx context.Context, name string, storeID uint64) error @@ -543,11 +541,6 @@ func (c *client) GetReplicateConfig(ctx context.Context) (map[string]any, error) return config, nil } -// UpdateCircuitBreakerSettings updates config for circuit breaker -func (c *client) UpdateCircuitBreakerSettings(apply func(config *cb.Settings)) { - c.inner.regionMetaCircuitBreaker.ChangeSettings(apply) -} - // GetAllPlacementRuleBundles gets all placement rules bundles. 
func (c *client) GetAllPlacementRuleBundles(ctx context.Context) ([]*GroupBundle, error) { var bundles []*GroupBundle diff --git a/client/http/request_info.go b/client/http/request_info.go index a8910474fbb..1e3449b59a0 100644 --- a/client/http/request_info.go +++ b/client/http/request_info.go @@ -89,14 +89,6 @@ const ( DeleteGCSafePointName = "DeleteGCSafePoint" ) -var regionRequestNames = map[string]struct{}{ - getRegionByIDName: {}, - getRegionByKeyName: {}, - getRegionsName: {}, - getRegionsByKeyRangeName: {}, - getRegionsByStoreIDName: {}, -} - type requestInfo struct { callerID string name string diff --git a/client/inner_client.go b/client/inner_client.go index 7e2add9b2b9..ae15c763854 100644 --- a/client/inner_client.go +++ b/client/inner_client.go @@ -3,14 +3,15 @@ package pd import ( "context" "crypto/tls" - "google.golang.org/grpc/codes" "sync" "time" + "google.golang.org/grpc/codes" + "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" - cb "github.com/tikv/pd/client/circuit_breaker" + cb "github.com/tikv/pd/client/circuitbreaker" "github.com/tikv/pd/client/clients/tso" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" diff --git a/client/opt/option.go b/client/opt/option.go index 73f7a3a440d..944eb1ca827 100644 --- a/client/opt/option.go +++ b/client/opt/option.go @@ -15,10 +15,11 @@ package opt import ( - cb "github.com/tikv/pd/client/circuit_breaker" "sync/atomic" "time" + cb "github.com/tikv/pd/client/circuitbreaker" + "github.com/pingcap/errors" "github.com/prometheus/client_golang/prometheus" "google.golang.org/grpc" From 6c49334be91c4c237a14f6517a0d2f1bc0e2e827 Mon Sep 17 00:00:00 2001 From: artem_danilov Date: Fri, 29 Nov 2024 10:55:43 -0800 Subject: [PATCH 5/8] minor tests and metrics refactoring Signed-off-by: artem_danilov --- client/circuitbreaker/circuit_breaker.go | 36 ++++++++++++------- client/circuitbreaker/circuit_breaker_test.go | 33 ++++++++++++----- 2 files changed, 48 
insertions(+), 21 deletions(-) diff --git a/client/circuitbreaker/circuit_breaker.go b/client/circuitbreaker/circuit_breaker.go index 33e59517bcc..9798ca15b07 100644 --- a/client/circuitbreaker/circuit_breaker.go +++ b/client/circuitbreaker/circuit_breaker.go @@ -71,6 +71,7 @@ type CircuitBreaker[T any] struct { successCounter prometheus.Counter failureCounter prometheus.Counter + overloadCounter prometheus.Counter fastFailCounter prometheus.Counter } @@ -109,7 +110,8 @@ func NewCircuitBreaker[T any](name string, st Settings) *CircuitBreaker[T] { metricName := replacer.Replace(name) cb.successCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "success") - cb.failureCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "failure") + cb.failureCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "error") + cb.overloadCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "overload") cb.fastFailCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "fast_fail") return cb } @@ -129,6 +131,7 @@ func (cb *CircuitBreaker[T]) ChangeSettings(apply func(config *Settings)) { func (cb *CircuitBreaker[T]) Execute(call func() (T, Overloading, error)) (T, error) { state, err := cb.onRequest() if err != nil { + cb.fastFailCounter.Inc() var defaultValue T return defaultValue, err } @@ -136,12 +139,14 @@ func (cb *CircuitBreaker[T]) Execute(call func() (T, Overloading, error)) (T, er defer func() { e := recover() if e != nil { + cb.emitMetric(Yes, err) cb.onResult(state, Yes) panic(e) } }() result, overloaded, err := call() + cb.emitMetric(overloaded, err) cb.onResult(state, overloaded) return result, err } @@ -155,13 +160,24 @@ func (cb *CircuitBreaker[T]) onRequest() (*State[T], error) { return state, err } -func (cb *CircuitBreaker[T]) onResult(state *State[T], open Overloading) { +func (cb *CircuitBreaker[T]) onResult(state *State[T], overloaded Overloading) { cb.mutex.Lock() defer cb.mutex.Unlock() - if cb.state == state { - 
state.onResult(open) - } // else the state moved forward so we don't need to update the counts + // even if the circuit breaker already moved to a new state while the request was in progress, + // it is still ok to update the old state, but it is not relevant anymore + state.onResult(overloaded) +} + +func (cb *CircuitBreaker[T]) emitMetric(overloaded Overloading, err error) { + if err == nil { + cb.successCounter.Inc() + } else { + cb.failureCounter.Inc() + } + if overloaded { + cb.overloadCounter.Inc() + } } // State represents the state of CircuitBreaker. @@ -216,7 +232,6 @@ func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { zap.String("name", cb.name), zap.Uint32("observedErrorRatePct", observedErrorRatePct), zap.String("config", fmt.Sprintf("%+v", cb.config))) - cb.fastFailCounter.Inc() return cb.newState(now, StateOpen), errs.ErrCircuitBreakerOpen } // the error threshold is not breached or there were not enough requests to evaluate it, @@ -234,7 +249,6 @@ func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { return cb.newState(now, StateHalfOpen), nil } else { // continue in the open state till CoolDownInterval is over - cb.fastFailCounter.Inc() return s, errs.ErrCircuitBreakerOpen } case StateHalfOpen: @@ -244,7 +258,6 @@ func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { log.Error("Circuit breaker goes from half-open to open again as errors persist and continue to fail all requests", zap.String("name", cb.name), zap.String("config", fmt.Sprintf("%+v", cb.config))) - cb.fastFailCounter.Inc() return cb.newState(now, StateOpen), errs.ErrCircuitBreakerOpen } else if s.successCount == s.cb.config.HalfOpenSuccessCount { // all probe requests are succeeded, we can move to closed state and allow all requests @@ -258,7 +271,6 @@ func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { return s, nil } else { // continue in half-open state till all probe requests are done and fail all 
other requests for now - cb.fastFailCounter.Inc() return s, errs.ErrCircuitBreakerOpen } default: @@ -266,14 +278,12 @@ func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { } } -func (s *State[T]) onResult(open Overloading) { - switch open { +func (s *State[T]) onResult(overloaded Overloading) { + switch overloaded { case No: s.successCount++ - s.cb.successCounter.Inc() case Yes: s.failureCount++ - s.cb.failureCounter.Inc() default: panic("unknown state") } diff --git a/client/circuitbreaker/circuit_breaker_test.go b/client/circuitbreaker/circuit_breaker_test.go index 012021a1bbf..f506b857466 100644 --- a/client/circuitbreaker/circuit_breaker_test.go +++ b/client/circuitbreaker/circuit_breaker_test.go @@ -83,8 +83,10 @@ func TestCircuitBreaker_Half_Open_To_Closed(t *testing.T) { driveQPS(cb, minCountToOpen, Yes, re) cb.advance(settings.ErrorRateWindow) assertFastFail(cb, re) + re.Equal(StateOpen, cb.state.stateType) cb.advance(settings.CoolDownInterval) assertSucceeds(cb, re) + re.Equal(StateHalfOpen, cb.state.stateType) assertSucceeds(cb, re) re.Equal(StateHalfOpen, cb.state.stateType) // state always transferred on the incoming request @@ -99,6 +101,7 @@ func TestCircuitBreaker_Half_Open_To_Open(t *testing.T) { driveQPS(cb, minCountToOpen, Yes, re) cb.advance(settings.ErrorRateWindow) assertFastFail(cb, re) + re.Equal(StateOpen, cb.state.stateType) cb.advance(settings.CoolDownInterval) assertSucceeds(cb, re) re.Equal(StateHalfOpen, cb.state.stateType) @@ -112,20 +115,17 @@ func TestCircuitBreaker_Half_Open_To_Open(t *testing.T) { re.Equal(StateOpen, cb.state.stateType) } +// in half open state, circuit breaker will allow only HalfOpenSuccessCount pending and should fast fail all other request till HalfOpenSuccessCount requests is completed +// this test moves circuit breaker to the half open state and verifies that requests above HalfOpenSuccessCount are failing func TestCircuitBreaker_Half_Open_Fail_Over_Pending_Count(t *testing.T) { re := 
require.New(t) - cb := NewCircuitBreaker[int]("test_cb", settings) - re.Equal(StateClosed, cb.state.stateType) - driveQPS(cb, minCountToOpen, Yes, re) - cb.advance(settings.ErrorRateWindow) - assertFastFail(cb, re) - re.Equal(StateOpen, cb.state.stateType) - cb.advance(settings.CoolDownInterval) + cb := newCircuitBreakerMovedToHalfOpenState(re) + // the next request will move circuit breaker into the half open state var started []chan bool var waited []chan bool var ended []chan bool - for range int(settings.HalfOpenSuccessCount) { + for range settings.HalfOpenSuccessCount { start := make(chan bool) wait := make(chan bool) end := make(chan bool) @@ -144,17 +144,23 @@ func TestCircuitBreaker_Half_Open_Fail_Over_Pending_Count(t *testing.T) { re.NoError(err) }() } + // make sure all requests are started for i := range started { <-started[i] } + // validate that requests beyond HalfOpenSuccessCount are failing assertFastFail(cb, re) re.Equal(StateHalfOpen, cb.state.stateType) + // unblock pending requests and wait till they are completed for i := range ended { waited[i] <- true <-ended[i] } + // validate that circuit breaker moves to closed state assertSucceeds(cb, re) re.Equal(StateClosed, cb.state.stateType) + // make sure that after moving to open state all counters are reset + re.Equal(uint32(1), cb.state.successCount) } func TestCircuitBreaker_ChangeSettings(t *testing.T) { @@ -177,6 +183,17 @@ func TestCircuitBreaker_ChangeSettings(t *testing.T) { re.Equal(StateOpen, cb.state.stateType) } +func newCircuitBreakerMovedToHalfOpenState(re *require.Assertions) *CircuitBreaker[int] { + cb := NewCircuitBreaker[int]("test_cb", settings) + re.Equal(StateClosed, cb.state.stateType) + driveQPS(cb, minCountToOpen, Yes, re) + cb.advance(settings.ErrorRateWindow) + assertFastFail(cb, re) + re.Equal(StateOpen, cb.state.stateType) + cb.advance(settings.CoolDownInterval) + return cb +} + func driveQPS(cb *CircuitBreaker[int], count int, overload Overloading, re 
*require.Assertions) { for range count { _, err := cb.Execute(func() (int, Overloading, error) { From daffcf6a53e11784ae1fbd0d00a66dfeaea802e7 Mon Sep 17 00:00:00 2001 From: artem_danilov Date: Wed, 4 Dec 2024 21:13:17 -0800 Subject: [PATCH 6/8] addressed more comments Signed-off-by: artem_danilov --- client/circuitbreaker/circuit_breaker.go | 50 ++++++++++++------- client/circuitbreaker/circuit_breaker_test.go | 13 +++++ 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/client/circuitbreaker/circuit_breaker.go b/client/circuitbreaker/circuit_breaker.go index 9798ca15b07..f2f23e5f977 100644 --- a/client/circuitbreaker/circuit_breaker.go +++ b/client/circuitbreaker/circuit_breaker.go @@ -70,7 +70,7 @@ type CircuitBreaker[T any] struct { state *State[T] successCounter prometheus.Counter - failureCounter prometheus.Counter + errorCounter prometheus.Counter overloadCounter prometheus.Counter fastFailCounter prometheus.Counter } @@ -110,7 +110,7 @@ func NewCircuitBreaker[T any](name string, st Settings) *CircuitBreaker[T] { metricName := replacer.Replace(name) cb.successCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "success") - cb.failureCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "error") + cb.errorCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "error") cb.overloadCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "overload") cb.fastFailCounter = m.CircuitBreakerCounters.WithLabelValues(metricName, "fast_fail") return cb @@ -170,13 +170,16 @@ func (cb *CircuitBreaker[T]) onResult(state *State[T], overloaded Overloading) { } func (cb *CircuitBreaker[T]) emitMetric(overloaded Overloading, err error) { - if err == nil { + switch overloaded { + case No: cb.successCounter.Inc() - } else { - cb.failureCounter.Inc() - } - if overloaded { + case Yes: cb.overloadCounter.Inc() + default: + panic("unknown state") + } + if err != nil { + cb.errorCounter.Inc() } } @@ -216,23 +219,32 @@ func (cb 
*CircuitBreaker[T]) newState(now time.Time, stateType StateType) *State } // onRequest transitions the state to the next state based on the current state and the previous requests results +// The implementation represents a state machine for CircuitBreaker // All state transitions happens at the request evaluation time only -// The implementation represents a state machine effectively +// The circuit breaker starts in the Closed state, which allows all requests to pass through and always lasts for a fixed duration of `Settings.ErrorRateWindow`. +// If `Settings.ErrorRateThresholdPct` is breached at the end of the window, then it moves to Open state; otherwise it moves to a new Closed state with a new window. +// Open state fails all requests; it has a fixed duration of `Settings.CoolDownInterval` and always moves to HalfOpen state at the end of the interval. +// HalfOpen state does not have a fixed duration and lasts till `Settings.HalfOpenSuccessCount` requests are evaluated. +// If any of those requests fails, then it moves back to Open state; otherwise it moves to Closed state. func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { var now = time.Now() switch s.stateType { case StateClosed: - if s.end.Before(now) { + if now.After(s.end) { // ErrorRateWindow is over, let's evaluate the error rate - total := s.failureCount + s.successCount - observedErrorRatePct := s.failureCount * 100 / total - if s.cb.config.ErrorRateThresholdPct > 0 && total >= uint32(s.cb.config.ErrorRateWindow.Seconds())*s.cb.config.MinQPSForOpen && observedErrorRatePct >= s.cb.config.ErrorRateThresholdPct { - // the error threshold is breached, let's move to open state and start failing all requests - log.Error("Circuit breaker tripped.
Starting to fail all requests", - zap.String("name", cb.name), - zap.Uint32("observedErrorRatePct", observedErrorRatePct), - zap.String("config", fmt.Sprintf("%+v", cb.config))) - return cb.newState(now, StateOpen), errs.ErrCircuitBreakerOpen + if s.cb.config.ErrorRateThresholdPct > 0 { // otherwise circuit breaker is disabled + total := s.failureCount + s.successCount + if total > 0 { + observedErrorRatePct := s.failureCount * 100 / total + if total >= uint32(s.cb.config.ErrorRateWindow.Seconds())*s.cb.config.MinQPSForOpen && observedErrorRatePct >= s.cb.config.ErrorRateThresholdPct { + // the error threshold is breached, let's move to open state and start failing all requests + log.Error("Circuit breaker tripped. Starting to fail all requests", + zap.String("name", cb.name), + zap.Uint32("observedErrorRatePct", observedErrorRatePct), + zap.String("config", fmt.Sprintf("%+v", cb.config))) + return cb.newState(now, StateOpen), errs.ErrCircuitBreakerOpen + } + } } // the error threshold is not breached or there were not enough requests to evaluate it, // continue in the closed state and allow all requests @@ -241,7 +253,7 @@ func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { // continue in closed state till ErrorRateWindow is over return s, nil case StateOpen: - if s.end.Before(now) { + if now.After(s.end) { // CoolDownInterval is over, it is time to transition to half-open state log.Info("Circuit breaker cooldown period is over. Transitioning to half-open state to test the service", zap.String("name", cb.name), diff --git a/client/circuitbreaker/circuit_breaker_test.go b/client/circuitbreaker/circuit_breaker_test.go index f506b857466..412c7fc43ab 100644 --- a/client/circuitbreaker/circuit_breaker_test.go +++ b/client/circuitbreaker/circuit_breaker_test.go @@ -1,3 +1,16 @@ +// Copyright 2024 TiKV Project Authors. 
+// // Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. package circuitbreaker import ( From 2ebfc1e9b0603c67a49cc700ab3c7adcb506fce4 Mon Sep 17 00:00:00 2001 From: artem_danilov Date: Thu, 5 Dec 2024 12:09:53 -0800 Subject: [PATCH 7/8] fix merge conflict issues Signed-off-by: artem_danilov --- client/opt/option.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/client/opt/option.go b/client/opt/option.go index 2963b0f782a..ff79f104154 100644 --- a/client/opt/option.go +++ b/client/opt/option.go @@ -216,6 +216,8 @@ func WithInitMetricsOption(initMetrics bool) ClientOption { func WithRegionMetaCircuitBreaker(config cb.Settings) ClientOption { return func(op *Option) { op.RegionMetaCircuitBreakerSettings = config + } +} // WithBackoffer configures the client with backoffer.
func WithBackoffer(bo *retry.Backoffer) ClientOption { From 45c805e7b60e8b1a178a7e67ae27bdd240396eae Mon Sep 17 00:00:00 2001 From: artem_danilov Date: Mon, 9 Dec 2024 19:15:03 -0800 Subject: [PATCH 8/8] improve testing Signed-off-by: artem_danilov --- client/circuitbreaker/circuit_breaker.go | 4 +- client/circuitbreaker/circuit_breaker_test.go | 45 ++++++++++++++++--- client/opt/option.go | 2 +- 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/client/circuitbreaker/circuit_breaker.go b/client/circuitbreaker/circuit_breaker.go index f2f23e5f977..b5a4c53ebb5 100644 --- a/client/circuitbreaker/circuit_breaker.go +++ b/client/circuitbreaker/circuit_breaker.go @@ -52,8 +52,8 @@ type Settings struct { HalfOpenSuccessCount uint32 } -// AlwaysOpenSettings is a configuration that never trips the circuit breaker. -var AlwaysOpenSettings = Settings{ +// AlwaysClosedSettings is a configuration that never trips the circuit breaker. +var AlwaysClosedSettings = Settings{ ErrorRateThresholdPct: 0, // never trips ErrorRateWindow: 10 * time.Second, // effectively results in testing for new settings every 10 seconds MinQPSForOpen: 10, diff --git a/client/circuitbreaker/circuit_breaker_test.go b/client/circuitbreaker/circuit_breaker_test.go index 412c7fc43ab..ca77b7f9f99 100644 --- a/client/circuitbreaker/circuit_breaker_test.go +++ b/client/circuitbreaker/circuit_breaker_test.go @@ -68,7 +68,7 @@ func TestCircuitBreaker_OpenState(t *testing.T) { re.Equal(StateOpen, cb.state.stateType) } -func TestCircuitBreaker_OpenState_Not_Enough_QPS(t *testing.T) { +func TestCircuitBreaker_CloseState_Not_Enough_QPS(t *testing.T) { re := require.New(t) cb := NewCircuitBreaker[int]("test_cb", settings) re.Equal(StateClosed, cb.state.stateType) @@ -78,7 +78,7 @@ func TestCircuitBreaker_OpenState_Not_Enough_QPS(t *testing.T) { re.Equal(StateClosed, cb.state.stateType) } -func TestCircuitBreaker_OpenState_Not_Enough_Error_Rate(t *testing.T) { +func 
TestCircuitBreaker_CloseState_Not_Enough_Error_Rate(t *testing.T) { re := require.New(t) cb := NewCircuitBreaker[int]("test_cb", settings) re.Equal(StateClosed, cb.state.stateType) @@ -176,12 +176,47 @@ func TestCircuitBreaker_Half_Open_Fail_Over_Pending_Count(t *testing.T) { re.Equal(uint32(1), cb.state.successCount) } +func TestCircuitBreaker_Count_Only_Requests_In_Same_Window(t *testing.T) { + re := require.New(t) + cb := NewCircuitBreaker[int]("test_cb", settings) + re.Equal(StateClosed, cb.state.stateType) + + start := make(chan bool) + wait := make(chan bool) + end := make(chan bool) + go func() { + defer func() { + end <- true + }() + _, err := cb.Execute(func() (int, Overloading, error) { + start <- true + <-wait + return 42, No, nil + }) + re.NoError(err) + }() + <-start // make sure the request is started + // assert running request is not counted + re.Equal(uint32(0), cb.state.successCount) + + // advance request to the next window + cb.advance(settings.ErrorRateWindow) + assertSucceeds(cb, re) + re.Equal(uint32(1), cb.state.successCount) + + // complete the request from the previous window + wait <- true // resume + <-end // wait for the request to complete + // assert request from last window is not counted + re.Equal(uint32(1), cb.state.successCount) +} + func TestCircuitBreaker_ChangeSettings(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", AlwaysOpenSettings) - driveQPS(cb, int(AlwaysOpenSettings.MinQPSForOpen*uint32(AlwaysOpenSettings.ErrorRateWindow.Seconds())), Yes, re) - cb.advance(AlwaysOpenSettings.ErrorRateWindow) + cb := NewCircuitBreaker[int]("test_cb", AlwaysClosedSettings) + driveQPS(cb, int(AlwaysClosedSettings.MinQPSForOpen*uint32(AlwaysClosedSettings.ErrorRateWindow.Seconds())), Yes, re) + cb.advance(AlwaysClosedSettings.ErrorRateWindow) assertSucceeds(cb, re) re.Equal(StateClosed, cb.state.stateType) diff --git a/client/opt/option.go b/client/opt/option.go index ff79f104154..9a80a895cc0 100644 --- 
a/client/opt/option.go +++ b/client/opt/option.go @@ -82,7 +82,7 @@ func NewOption() *Option { MaxRetryTimes: maxInitClusterRetries, EnableTSOFollowerProxyCh: make(chan struct{}, 1), InitMetrics: true, - RegionMetaCircuitBreakerSettings: cb.AlwaysOpenSettings, + RegionMetaCircuitBreakerSettings: cb.AlwaysClosedSettings, } co.dynamicOptions[MaxTSOBatchWaitInterval].Store(defaultMaxTSOBatchWaitInterval)