diff --git a/balancer/rls/control_channel.go b/balancer/rls/control_channel.go
index 60e6a021d133..de28fc3e9b4e 100644
--- a/balancer/rls/control_channel.go
+++ b/balancer/rls/control_channel.go
@@ -21,6 +21,7 @@ package rls
 import (
 	"context"
 	"fmt"
+	"sync"
 	"time"
 
 	"google.golang.org/grpc"
@@ -29,7 +30,6 @@ import (
 	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/credentials/insecure"
 	"google.golang.org/grpc/internal"
-	"google.golang.org/grpc/internal/buffer"
 	internalgrpclog "google.golang.org/grpc/internal/grpclog"
 	"google.golang.org/grpc/internal/grpcsync"
 	"google.golang.org/grpc/internal/pretty"
@@ -57,12 +57,12 @@ type controlChannel struct {
 	// hammering the RLS service while it is overloaded or down.
 	throttler adaptiveThrottler
 
-	cc                  *grpc.ClientConn
-	client              rlsgrpc.RouteLookupServiceClient
-	logger              *internalgrpclog.PrefixLogger
-	connectivityStateCh *buffer.Unbounded
-	unsubscribe         func()
-	monitorDoneCh       chan struct{}
+	cc                   *grpc.ClientConn
+	client               rlsgrpc.RouteLookupServiceClient
+	logger               *internalgrpclog.PrefixLogger
+	unsubscribe          func()
+	seenTransientFailure bool
+	mu                   sync.Mutex
 }
 
 // newControlChannel creates a controlChannel to rlsServerName and uses
@@ -70,11 +70,9 @@ type controlChannel struct {
 // gRPC channel.
 func newControlChannel(rlsServerName, serviceConfig string, rpcTimeout time.Duration, bOpts balancer.BuildOptions, backToReadyFunc func()) (*controlChannel, error) {
 	ctrlCh := &controlChannel{
-		rpcTimeout:          rpcTimeout,
-		backToReadyFunc:     backToReadyFunc,
-		throttler:           newAdaptiveThrottler(),
-		connectivityStateCh: buffer.NewUnbounded(),
-		monitorDoneCh:       make(chan struct{}),
+		rpcTimeout:      rpcTimeout,
+		backToReadyFunc: backToReadyFunc,
+		throttler:       newAdaptiveThrottler(),
 	}
 	ctrlCh.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[rls-control-channel %p] ", ctrlCh))
 
@@ -92,7 +90,6 @@ func newControlChannel(rlsServerName, serviceConfig string, rpcTimeout time.Dura
 	ctrlCh.cc.Connect()
 	ctrlCh.client = rlsgrpc.NewRouteLookupServiceClient(ctrlCh.cc)
 	ctrlCh.logger.Infof("Control channel created to RLS server at: %v", rlsServerName)
-	go ctrlCh.monitorConnectivityState()
 	return ctrlCh, nil
 }
 
@@ -101,7 +98,34 @@ func (cc *controlChannel) OnMessage(msg any) {
 	if !ok {
 		panic(fmt.Sprintf("Unexpected message type %T , wanted connectectivity.State type", msg))
 	}
-	cc.connectivityStateCh.Put(st)
+
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+
+	switch st {
+	case connectivity.Ready:
+		// Only reset backoff when transitioning from TRANSIENT_FAILURE to READY.
+		// This indicates the RLS server has recovered from being unreachable, so
+		// we reset backoff state in all cache entries to allow pending RPCs to
+		// proceed immediately. We skip benign transitions like READY → IDLE → READY
+		// since those don't represent actual failures.
+		if cc.seenTransientFailure {
+			cc.logger.Infof("Control channel back to READY after TRANSIENT_FAILURE")
+			cc.seenTransientFailure = false
+			if cc.backToReadyFunc != nil {
+				cc.backToReadyFunc()
+			}
+		} else {
+			cc.logger.Infof("Control channel is READY")
+		}
+	case connectivity.TransientFailure:
+		// Track that we've entered TRANSIENT_FAILURE state so we know to reset
+		// backoffs when we recover to READY.
+		cc.logger.Infof("Control channel is TRANSIENT_FAILURE")
+		cc.seenTransientFailure = true
+	default:
+		cc.logger.Infof("Control channel connectivity state is %s", st)
+	}
 }
 
 // dialOpts constructs the dial options for the control plane channel.
@@ -148,68 +172,8 @@ func (cc *controlChannel) dialOpts(bOpts balancer.BuildOptions, serviceConfig st
 	return dopts, nil
 }
 
-func (cc *controlChannel) monitorConnectivityState() {
-	cc.logger.Infof("Starting connectivity state monitoring goroutine")
-	defer close(cc.monitorDoneCh)
-
-	// Since we use two mechanisms to deal with RLS server being down:
-	//   - adaptive throttling for the channel as a whole
-	//   - exponential backoff on a per-request basis
-	// we need a way to avoid double-penalizing requests by counting failures
-	// toward both mechanisms when the RLS server is unreachable.
-	//
-	// To accomplish this, we monitor the state of the control plane channel. If
-	// the state has been TRANSIENT_FAILURE since the last time it was in state
-	// READY, and it then transitions into state READY, we push on a channel
-	// which is being read by the LB policy.
-	//
-	// The LB the policy will iterate through the cache to reset the backoff
-	// timeouts in all cache entries. Specifically, this means that it will
-	// reset the backoff state and cancel the pending backoff timer. Note that
-	// when cancelling the backoff timer, just like when the backoff timer fires
-	// normally, a new picker is returned to the channel, to force it to
-	// re-process any wait-for-ready RPCs that may still be queued if we failed
-	// them while we were in backoff. However, we should optimize this case by
-	// returning only one new picker, regardless of how many backoff timers are
-	// cancelled.
-
-	// Wait for the control channel to become READY for the first time.
-	for s, ok := <-cc.connectivityStateCh.Get(); s != connectivity.Ready; s, ok = <-cc.connectivityStateCh.Get() {
-		if !ok {
-			return
-		}
-
-		cc.connectivityStateCh.Load()
-		if s == connectivity.Shutdown {
-			return
-		}
-	}
-	cc.connectivityStateCh.Load()
-	cc.logger.Infof("Connectivity state is READY")
-
-	for {
-		s, ok := <-cc.connectivityStateCh.Get()
-		if !ok {
-			return
-		}
-		cc.connectivityStateCh.Load()
-
-		if s == connectivity.Shutdown {
-			return
-		}
-		if s == connectivity.Ready {
-			cc.logger.Infof("Control channel back to READY")
-			cc.backToReadyFunc()
-		}
-
-		cc.logger.Infof("Connectivity state is %s", s)
-	}
-}
-
 func (cc *controlChannel) close() {
 	cc.unsubscribe()
-	cc.connectivityStateCh.Close()
-	<-cc.monitorDoneCh
 	cc.cc.Close()
 	cc.logger.Infof("Shutdown")
 }
diff --git a/balancer/rls/control_channel_test.go b/balancer/rls/control_channel_test.go
index 5a30820c3b47..dbe41db893a1 100644
--- a/balancer/rls/control_channel_test.go
+++ b/balancer/rls/control_channel_test.go
@@ -26,6 +26,7 @@ import (
 	"fmt"
 	"os"
 	"regexp"
+	"sync"
 	"testing"
 	"time"
 
@@ -33,6 +34,7 @@ import (
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/balancer"
 	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/credentials"
 	"google.golang.org/grpc/internal"
 	rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
@@ -463,3 +465,119 @@ func (s) TestNewControlChannelUnsupportedCredsBundle(t *testing.T) {
 		t.Fatal("newControlChannel succeeded when expected to fail")
 	}
 }
+
+// TestControlChannelConnectivityStateTransitions verifies that the control
+// channel only resets backoff when recovering from TRANSIENT_FAILURE, not
+// when going through benign state changes like READY → IDLE → READY.
+func (s) TestControlChannelConnectivityStateTransitions(t *testing.T) {
+	tests := []struct {
+		name              string
+		states            []connectivity.State
+		wantCallbackCount int
+	}{
+		{
+			name: "READY → TRANSIENT_FAILURE → READY triggers callback",
+			states: []connectivity.State{
+				connectivity.TransientFailure,
+				connectivity.Ready,
+			},
+			wantCallbackCount: 1,
+		},
+		{
+			name: "READY → IDLE → READY does not trigger callback",
+			states: []connectivity.State{
+				connectivity.Idle,
+				connectivity.Ready,
+			},
+			wantCallbackCount: 0,
+		},
+		{
+			name: "Multiple failures trigger callback each time",
+			states: []connectivity.State{
+				connectivity.TransientFailure,
+				connectivity.Ready,
+				connectivity.TransientFailure,
+				connectivity.Ready,
+			},
+			wantCallbackCount: 2,
+		},
+		{
+			name: "IDLE between failures doesn't affect callback",
+			states: []connectivity.State{
+				connectivity.TransientFailure,
+				connectivity.Idle,
+				connectivity.Ready,
+			},
+			wantCallbackCount: 1,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Start an RLS server.
+			rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil)
+
+			// Set up a callback that counts its invocations.
+			var mu sync.Mutex
+			var callbackCount int
+			// Buffered channel large enough to never block.
+			callbackInvoked := make(chan struct{}, 100)
+			callback := func() {
+				mu.Lock()
+				callbackCount++
+				mu.Unlock()
+				// Send on the channel; this should never block given the large buffer.
+				callbackInvoked <- struct{}{}
+			}
+
+			// Create the control channel.
+			ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, balancer.BuildOptions{}, callback)
+			if err != nil {
+				t.Fatalf("Failed to create control channel: %v", err)
+			}
+			defer ctrlCh.close()
+
+			// Inject all test states.
+			for _, state := range tt.states {
+				ctrlCh.OnMessage(state)
+			}
+
+			// Wait for all expected callbacks, with a timeout.
+			callbackTimeout := time.NewTimer(defaultTestTimeout)
+			defer callbackTimeout.Stop()
+
+			receivedCallbacks := 0
+			for receivedCallbacks < tt.wantCallbackCount {
+				select {
+				case <-callbackInvoked:
+					receivedCallbacks++
+				case <-callbackTimeout.C:
+					mu.Lock()
+					got := callbackCount
+					mu.Unlock()
+					t.Fatalf("Timeout waiting for callbacks: expected %d, received %d via channel, callback count is %d", tt.wantCallbackCount, receivedCallbacks, got)
+				}
+			}
+
+			// Verify that the final callback count matches the expected value.
+			mu.Lock()
+			gotCallbackCount := callbackCount
+			mu.Unlock()
+
+			if gotCallbackCount != tt.wantCallbackCount {
+				t.Errorf("Got %d callback invocations, want %d", gotCallbackCount, tt.wantCallbackCount)
+			}
+
+			// Ensure no extra callbacks are invoked.
+			select {
+			case <-callbackInvoked:
+				mu.Lock()
+				final := callbackCount
+				mu.Unlock()
+				t.Fatalf("Received more callbacks than expected: got %d, want %d", final, tt.wantCallbackCount)
+			case <-time.After(50 * time.Millisecond):
+				// Expected: no more callbacks.
+			}
+		})
+	}
+}
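
The core of this change is the small state machine in OnMessage: remember whether the channel has been in TRANSIENT_FAILURE since it was last READY, and invoke backToReadyFunc only on the TRANSIENT_FAILURE → READY edge. The sketch below, which is not part of the diff, isolates that detection logic under some assumptions: the stateTracker type, its onStateChange method, and the onBackToReady field are illustrative names invented for this example, and the mutex is omitted because the sketch is driven from a single goroutine.

package main

import (
	"fmt"

	"google.golang.org/grpc/connectivity"
)

// stateTracker is a simplified, hypothetical stand-in for controlChannel's
// connectivity handling. It invokes onBackToReady only on a
// TRANSIENT_FAILURE → READY recovery and ignores benign transitions such as
// READY → IDLE → READY.
type stateTracker struct {
	seenTransientFailure bool
	onBackToReady        func()
}

func (t *stateTracker) onStateChange(s connectivity.State) {
	switch s {
	case connectivity.Ready:
		if t.seenTransientFailure {
			t.seenTransientFailure = false
			if t.onBackToReady != nil {
				t.onBackToReady()
			}
		}
	case connectivity.TransientFailure:
		t.seenTransientFailure = true
	}
}

func main() {
	resets := 0
	tr := &stateTracker{onBackToReady: func() { resets++ }}

	// Benign hop: READY → IDLE → READY must not reset backoffs.
	for _, s := range []connectivity.State{connectivity.Ready, connectivity.Idle, connectivity.Ready} {
		tr.onStateChange(s)
	}
	// Recovery: TRANSIENT_FAILURE → READY resets backoffs exactly once.
	for _, s := range []connectivity.State{connectivity.TransientFailure, connectivity.Ready} {
		tr.onStateChange(s)
	}
	fmt.Println("backoff resets:", resets) // Output: backoff resets: 1
}

Running the sketch prints "backoff resets: 1": the READY → IDLE → READY hop does not fire the callback, while the TRANSIENT_FAILURE → READY recovery does, mirroring the first two cases of the new test.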