diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go
index e55a0705556..541f4c5a569 100644
--- a/client/internal/dns/host.go
+++ b/client/internal/dns/host.go
@@ -88,13 +88,15 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
 		if len(nsConfig.NameServers) == 0 {
 			continue
 		}
-		if nsConfig.Primary {
+
+		if nsConfig.Primary && nsConfig.Enabled {
 			config.RouteAll = true
 		}

 		for _, domain := range nsConfig.Domains {
 			config.Domains = append(config.Domains, DomainConfig{
 				Domain:    strings.TrimSuffix(domain, "."),
+				Disabled:  !nsConfig.Enabled,
 				MatchOnly: !nsConfig.SearchDomainsEnabled,
 			})
 		}
diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go
index a4651ebb5b0..974576b5684 100644
--- a/client/internal/dns/server.go
+++ b/client/internal/dns/server.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"net/netip"
 	"runtime"
+	"slices"
 	"strings"
 	"sync"

@@ -297,6 +298,11 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
 		s.service.Stop()
 	}

+	// update each NameServerGroup config in nbdns.Config based on peerCount and IP.IsPrivate()
+	if runtime.GOOS != "android" && runtime.GOOS != "ios" {
+		s.toggleNameServerGroupsOnStatus(update.NameServerGroups)
+	}
+
 	localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
 	if err != nil {
 		return fmt.Errorf("not applying dns update, error: %v", err)
@@ -308,6 +314,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {

 	muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic
 	s.updateMux(muxUpdates)
+
 	s.updateLocalResolver(localRecords)

 	s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
@@ -359,7 +366,6 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
 }

 func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
-
 	var muxUpdates []muxUpdate
 	for _, nsGroup := range nameServerGroups {
 		if len(nsGroup.NameServers) == 0 {
@@ -403,7 +409,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
 		// contains this upstream settings (temporal deactivation not removed it)
 		handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)

-		if nsGroup.Primary {
+		if nsGroup.Primary && nsGroup.Enabled {
 			muxUpdates = append(muxUpdates, muxUpdate{
 				domain:  nbdns.RootZone,
 				handler: handler,
@@ -421,6 +427,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
 				handler.stop()
 				return nil, fmt.Errorf("received a nameserver group with an empty domain element")
 			}
+			if !nsGroup.Enabled {
+				continue
+			}
 			muxUpdates = append(muxUpdates, muxUpdate{
 				domain:  domain,
 				handler: handler,
@@ -484,6 +493,19 @@ func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord
 	s.localResolver.registeredMap = updatedMap
 }

+func (s *DefaultServer) toggleNameServerGroupsOnStatus(nameServerGroups []*nbdns.NameServerGroup) {
+	peerCount := s.statusRecorder.GetConnectedPeersCount()
+	for _, nsGroup := range nameServerGroups {
+		var hasPublicNameServer bool
+		for _, s := range nsGroup.NameServers {
+			if !s.IP.IsPrivate() {
+				hasPublicNameServer = true
+			}
+		}
+		nsGroup.Enabled = hasPublicNameServer || (peerCount >= 1)
+	}
+}
+
 func getNSHostPort(ns nbdns.NameServer) string {
 	return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
 }
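Note on the hunk above: toggleNameServerGroupsOnStatus keys each group's Enabled flag off two signals, whether the group has at least one public nameserver and whether at least one peer is currently connected, the idea being that private-only upstreams are unreachable until the overlay is up. A minimal stand-alone sketch of that rule; the helper name and the use of netip here are illustrative only, not part of the change:

    package main

    import (
        "fmt"
        "net/netip"
    )

    // groupEnabled mirrors the toggle logic for a single group: private-only
    // nameservers are assumed reachable only through connected peers, so they
    // stay enabled only while at least one peer is up.
    func groupEnabled(nameServerIPs []netip.Addr, connectedPeers int) bool {
        hasPublic := false
        for _, ip := range nameServerIPs {
            if !ip.IsPrivate() {
                hasPublic = true
            }
        }
        return hasPublic || connectedPeers >= 1
    }

    func main() {
        private := []netip.Addr{netip.MustParseAddr("10.0.0.53")}
        public := []netip.Addr{netip.MustParseAddr("8.8.8.8")}

        fmt.Println(groupEnabled(private, 0)) // false: private upstream, no peers
        fmt.Println(groupEnabled(private, 3)) // true: peers can route to it
        fmt.Println(groupEnabled(public, 0))  // true: public upstreams stay on
    }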
@@ -495,7 +517,6 @@ func (s *DefaultServer) upstreamCallbacks(
 	nsGroup *nbdns.NameServerGroup,
 	handler dns.Handler,
 ) (deactivate func(error), reactivate func()) {
-	var removeIndex map[string]int
 	deactivate = func(err error) {
 		s.mux.Lock()
 		defer s.mux.Unlock()
@@ -503,21 +524,15 @@ func (s *DefaultServer) upstreamCallbacks(
 		l := log.WithField("nameservers", nsGroup.NameServers)
 		l.Info("Temporarily deactivating nameservers group due to timeout")

-		removeIndex = make(map[string]int)
-		for _, domain := range nsGroup.Domains {
-			removeIndex[domain] = -1
-		}
 		if nsGroup.Primary {
-			removeIndex[nbdns.RootZone] = -1
 			s.currentConfig.RouteAll = false
 			s.service.DeregisterMux(nbdns.RootZone)
 		}

 		for i, item := range s.currentConfig.Domains {
-			if _, found := removeIndex[item.Domain]; found {
+			if slices.Contains(nsGroup.Domains, item.Domain) {
 				s.currentConfig.Domains[i].Disabled = true
 				s.service.DeregisterMux(item.Domain)
-				removeIndex[item.Domain] = i
 			}
 		}

@@ -530,18 +545,16 @@ func (s *DefaultServer) upstreamCallbacks(
 		}

 		s.updateNSState(nsGroup, err, false)
-
 	}
 	reactivate = func() {
 		s.mux.Lock()
 		defer s.mux.Unlock()

-		for domain, i := range removeIndex {
-			if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain {
-				continue
+		for i, item := range s.currentConfig.Domains {
+			if slices.Contains(nsGroup.Domains, item.Domain) {
+				s.currentConfig.Domains[i].Disabled = false
+				s.service.RegisterMux(item.Domain, handler)
 			}
-			s.currentConfig.Domains[i].Disabled = false
-			s.service.RegisterMux(domain, handler)
 		}

 		l := log.WithField("nameservers", nsGroup.NameServers)
@@ -588,17 +601,22 @@ func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {

 	for _, group := range groups {
 		var servers []string
+		var nsError error
+		if !group.Enabled {
+			nsError = fmt.Errorf("no peers connected")
+		}
 		for _, ns := range group.NameServers {
 			servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
 		}

+		// Automatically disabled if peer count == 0 and the nameserver IPs are private
 		state := peer.NSGroupState{
 			ID:      generateGroupKey(group),
 			Servers: servers,
 			Domains: group.Domains,
 			// The probe will determine the state, default enabled
-			Enabled: true,
-			Error:   nil,
+			Enabled: group.Enabled,
+			Error:   nsError,
 		}
 		states = append(states, state)
 	}
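The deactivate/reactivate callbacks above now toggle matching entries in s.currentConfig.Domains directly with slices.Contains instead of bookkeeping indices in a removeIndex map. A rough, self-contained sketch of that toggle, with DomainConfig and the mux registration stubbed out for illustration:

    package main

    import (
        "fmt"
        "slices"
    )

    type domainConfig struct {
        Domain   string
        Disabled bool
    }

    // setDisabled flips the Disabled flag for every configured domain that the
    // nameserver group serves, mirroring the slices.Contains loops in the diff.
    func setDisabled(configured []domainConfig, groupDomains []string, disabled bool) {
        for i, item := range configured {
            if slices.Contains(groupDomains, item.Domain) {
                configured[i].Disabled = disabled
            }
        }
    }

    func main() {
        cfg := []domainConfig{{Domain: "netbird.io"}, {Domain: "example.com"}}
        setDisabled(cfg, []string{"netbird.io"}, true) // deactivate on timeout
        fmt.Printf("%+v\n", cfg)
        setDisabled(cfg, []string{"netbird.io"}, false) // reactivate after a probe succeeds
        fmt.Printf("%+v\n", cfg)
    }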
diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go
index b9552bc17c0..62cd5916f32 100644
--- a/client/internal/dns/server_test.go
+++ b/client/internal/dns/server_test.go
@@ -7,6 +7,7 @@ import (
 	"net/netip"
 	"os"
 	"strings"
+	"sync"
 	"testing"
 	"time"

@@ -126,10 +127,12 @@ func TestUpdateDNSServer(t *testing.T) {
 				},
 				NameServerGroups: []*nbdns.NameServerGroup{
 					{
+						Enabled:     true,
 						Domains:     []string{"netbird.io"},
 						NameServers: nameServers,
 					},
 					{
+						Enabled:     true,
 						NameServers: nameServers,
 						Primary:     true,
 					},
@@ -154,6 +157,7 @@ func TestUpdateDNSServer(t *testing.T) {
 				},
 				NameServerGroups: []*nbdns.NameServerGroup{
 					{
+						Enabled:     true,
 						Domains:     []string{"netbird.io"},
 						NameServers: nameServers,
 					},
@@ -279,7 +283,16 @@ func TestUpdateDNSServer(t *testing.T) {
 					t.Log(err)
 				}
 			}()
-			dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
+			statusRecorder := peer.NewRecorder("https://mgm")
+			key := "abc"
+			statusRecorder.AddPeer(key, "abc.netbird")
+			statusRecorder.UpdatePeerState(peer.State{
+				PubKey:           key,
+				Mux:              new(sync.RWMutex),
+				ConnStatus:       peer.StatusConnected,
+				ConnStatusUpdate: time.Now(),
+			})
+			dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", statusRecorder)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -427,10 +440,12 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
 			{
 				Domains:     []string{"netbird.io"},
 				NameServers: nameServers,
+				Enabled:     true,
 			},
 			{
 				NameServers: nameServers,
 				Primary:     true,
+				Enabled:     true,
 			},
 		},
 	}
diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go
index b3baf2fa8fd..10a69771f1f 100644
--- a/client/internal/dns/upstream.go
+++ b/client/internal/dns/upstream.go
@@ -56,7 +56,7 @@ type upstreamResolverBase struct {
 func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *upstreamResolverBase {
 	ctx, cancel := context.WithCancel(ctx)

-	return &upstreamResolverBase{
+	resolverBase := &upstreamResolverBase{
 		ctx:             ctx,
 		cancel:          cancel,
 		upstreamTimeout: upstreamTimeout,
@@ -64,6 +64,92 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *
 		failsTillDeact:  failsTillDeact,
 		statusRecorder:  statusRecorder,
 	}
+
+	go resolverBase.watchPeersConnStatusChanges()
+
+	return resolverBase
+}
+
+func (u *upstreamResolverBase) watchPeersConnStatusChanges() {
+	var probeRunning atomic.Bool
+	var cancelBackOff context.CancelFunc
+
+	exponentialBackOff := &backoff.ExponentialBackOff{
+		InitialInterval:     200 * time.Millisecond,
+		RandomizationFactor: 0.5,
+		Multiplier:          1.1,
+		MaxInterval:         5 * time.Second,
+		MaxElapsedTime:      15 * time.Second,
+		Stop:                backoff.Stop,
+		Clock:               backoff.SystemClock,
+	}
+	operation := func() error {
+		select {
+		case <-u.ctx.Done():
+			return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context: %s", u.upstreamServers, u.ctx.Err()))
+		default:
+		}
+
+		u.probeAvailability()
+		if u.disabled {
+			return fmt.Errorf("probe failed")
+		}
+		return nil
+	}
+
+	continualProbe := func() {
+		// probe continually for 30s when peer count >= 1
+		if u.statusRecorder.GetConnectedPeersCount() == 0 {
+			log.Debug("0 peers connected, running one more DNS probe")
+			// cancel backoff operation
+			if cancelBackOff != nil {
+				cancelBackOff()
+				cancelBackOff = nil
+			}
+			u.probeAvailability()
+			return
+		}
+
+		if probeRunning.Load() {
+			log.Info("restart DNS probing")
+			cancelBackOff()
+			cancelBackOff = nil
+		}
+		defer func() {
+			u.mutex.Lock()
+			log.Infof("DNS probe finished, servers %s disabled: %t", u.upstreamServers, u.disabled)
+			u.mutex.Unlock()
+			probeRunning.Store(false)
+		}()
+		probeRunning.Store(true)
+
+		ctx, cancel := context.WithCancel(context.Background())
+		cancelBackOff = cancel
+		err := backoff.Retry(func() error {
+			select {
+			case <-ctx.Done():
+				return backoff.Permanent(ctx.Err())
+			default:
+				return operation()
+			}
+		}, backoff.WithContext(exponentialBackOff, ctx))
+		cancelBackOff = nil
+		if err != nil {
+			log.Warnf("DNS probe (peer ConnStatus change) stopped: %s", err)
+			u.disable(err)
+			return
+		}
+	}

+	for {
+		select {
+		case <-u.ctx.Done():
+			return
+		case <-u.statusRecorder.GetPeersConnStatusChangeNotifier():
+			log.Debugf("probing DNS availability on/off for 30s")
+			go continualProbe()
+		}
+	}
 }

 func (u *upstreamResolverBase) stop() {
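watchPeersConnStatusChanges above couples two pieces: a notifier channel from peer.Status that is closed whenever a peer's connection status flips, and a bounded exponential-backoff probe of the upstreams (roughly 15 seconds per round, per MaxElapsedTime). A condensed, runnable sketch of that wait-then-probe loop, assuming the cenkalti/backoff/v4 package already used by the client; notifier and probe are stand-ins for GetPeersConnStatusChangeNotifier and probeAvailability:

    package main

    import (
        "context"
        "errors"
        "fmt"
        "time"

        "github.com/cenkalti/backoff/v4"
    )

    // watch blocks on the notifier channel and, on every peer connection-status
    // change, retries probe() with a bounded exponential backoff.
    func watch(ctx context.Context, notifier func() <-chan struct{}, probe func() error) {
        bo := backoff.NewExponentialBackOff()
        bo.InitialInterval = 200 * time.Millisecond
        bo.MaxInterval = 5 * time.Second
        bo.MaxElapsedTime = 15 * time.Second // give up after ~15s, like the diff

        for {
            select {
            case <-ctx.Done():
                return
            case <-notifier(): // channel is closed once per status change
                err := backoff.Retry(probe, backoff.WithContext(bo, ctx))
                fmt.Println("probe round finished:", err)
                bo.Reset()
            }
        }
    }

    func main() {
        // Simulate a single "peer connected" event after 100ms.
        ch := make(chan struct{})
        go func() { time.Sleep(100 * time.Millisecond); close(ch) }()

        ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
        defer cancel()

        fired := false
        attempts := 0
        watch(ctx,
            func() <-chan struct{} {
                if !fired {
                    fired = true
                    return ch
                }
                return make(chan struct{}) // later calls never fire in this sketch
            },
            func() error {
                attempts++
                if attempts < 3 {
                    return errors.New("upstream still unreachable")
                }
                return nil // probe succeeded
            })
    }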
@@ -163,7 +249,7 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
 }

 // probeAvailability tests all upstream servers simultaneously and
-// disables the resolver if none work
+// disables/enables the resolver
 func (u *upstreamResolverBase) probeAvailability() {
 	u.mutex.Lock()
 	defer u.mutex.Unlock()
@@ -174,11 +260,6 @@ func (u *upstreamResolverBase) probeAvailability() {
 	default:
 	}

-	// avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact
-	if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact {
-		return
-	}
-
 	var success bool
 	var mu sync.Mutex
 	var wg sync.WaitGroup
@@ -190,7 +271,7 @@ func (u *upstreamResolverBase) probeAvailability() {
 		wg.Add(1)
 		go func() {
 			defer wg.Done()
-			err := u.testNameserver(upstream, 500*time.Millisecond)
+			err := u.testNameserver(upstream, probeTimeout)
 			if err != nil {
 				errors = multierror.Append(errors, err)
 				log.Warnf("probing upstream nameserver %s: %s", upstream, err)
@@ -208,6 +289,15 @@ func (u *upstreamResolverBase) probeAvailability() {
 	// didn't find a working upstream server, let's disable and try later
 	if !success {
 		u.disable(errors.ErrorOrNil())
+		return
+	}
+
+	if u.disabled {
+		log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
+		u.failsCount.Store(0)
+		u.successCount.Add(1)
+		u.reactivate()
+		u.disabled = false
 	}
 }

@@ -223,37 +313,22 @@ func (u *upstreamResolverBase) waitUntilResponse() {
 		Clock:               backoff.SystemClock,
 	}

-	operation := func() error {
-		select {
-		case <-u.ctx.Done():
-			return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServers))
-		default:
+	err := backoff.Retry(func() error {
+		if u.disabled {
+			u.probeAvailability()
 		}

-		for _, upstream := range u.upstreamServers {
-			if err := u.testNameserver(upstream, probeTimeout); err != nil {
-				log.Tracef("upstream check for %s: %s", upstream, err)
-			} else {
-				// at least one upstream server is available, stop probing
-				return nil
-			}
+		// check if still disabled
+		if u.disabled {
+			log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServers, exponentialBackOff.NextBackOff())
+			return fmt.Errorf("upstream check call error")
 		}
-
-		log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServers, exponentialBackOff.NextBackOff())
-		return fmt.Errorf("upstream check call error")
-	}
-
-	err := backoff.Retry(operation, exponentialBackOff)
+		return nil
+	}, exponentialBackOff)
 	if err != nil {
 		log.Warn(err)
 		return
 	}
-
-	log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
-	u.failsCount.Store(0)
-	u.successCount.Add(1)
-	u.reactivate()
-	u.disabled = false
 }

 // isTimeout returns true if the given error is a network timeout error.
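probeAvailability now owns both directions: a failed probe still disables the resolver, while a successful probe on a currently disabled resolver resets the failure counter and re-registers the handlers (the reactivation block that used to live in waitUntilResponse). A stripped-down sketch of that outcome handling, with the resolver stubbed out; names here are placeholders, not the real types:

    package main

    import (
        "errors"
        "fmt"
        "sync/atomic"
    )

    type resolver struct {
        disabled   bool
        failsCount atomic.Int32
    }

    func (r *resolver) disable(err error) {
        r.disabled = true
        fmt.Println("disabled:", err)
    }

    func (r *resolver) reactivate() {
        fmt.Println("re-registering upstream handlers")
    }

    // handleProbeResult mirrors the tail of probeAvailability in the diff.
    func (r *resolver) handleProbeResult(success bool, probeErr error) {
        if !success {
            r.disable(probeErr)
            return
        }
        if r.disabled {
            r.failsCount.Store(0)
            r.reactivate()
            r.disabled = false
        }
    }

    func main() {
        r := &resolver{}
        r.handleProbeResult(false, errors.New("all upstreams timed out")) // goes dark
        r.handleProbeResult(true, nil)                                    // comes back
        fmt.Println("disabled:", r.disabled)
    }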
diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go
index c1251dcc1e9..c754f22d1f3 100644
--- a/client/internal/dns/upstream_test.go
+++ b/client/internal/dns/upstream_test.go
@@ -8,10 +8,10 @@ import (
 	"time"

 	"github.com/miekg/dns"
+	"github.com/netbirdio/netbird/client/internal/peer"
 )

 func TestUpstreamResolver_ServeDNS(t *testing.T) {
-
 	testCases := []struct {
 		name     string
 		inputMSG *dns.Msg
@@ -58,7 +58,11 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
 	for _, testCase := range testCases {
 		t.Run(testCase.name, func(t *testing.T) {
 			ctx, cancel := context.WithCancel(context.TODO())
-			resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil)
+			statusRecorder := peer.NewRecorder("https://mgm")
+			key := "abc"
+			// Public resolvers being used, so a connected peer is not required
+			statusRecorder.AddPeer(key, "abc.netbird")
+			resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, statusRecorder, nil)
 			resolver.upstreamServers = testCase.InputServers
 			resolver.upstreamTimeout = testCase.timeout
 			if testCase.cancelCTX {
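Both test files now need a populated peer.Status recorder instead of a bare &peer.Status{}. If this setup keeps growing, it could be pulled into a small shared helper like the sketch below (not part of the diff); it only uses the calls already shown above:

    package dns

    import (
        "sync"
        "testing"
        "time"

        "github.com/netbirdio/netbird/client/internal/peer"
    )

    // newConnectedStatusRecorder builds a Status recorder that already reports
    // one connected peer, so code calling GetConnectedPeersCount sees a
    // non-zero value during the test.
    func newConnectedStatusRecorder(t *testing.T) *peer.Status {
        t.Helper()
        rec := peer.NewRecorder("https://mgm")
        key := "abc"
        if err := rec.AddPeer(key, "abc.netbird"); err != nil {
            t.Fatalf("add peer: %v", err)
        }
        err := rec.UpdatePeerState(peer.State{
            PubKey:           key,
            Mux:              new(sync.RWMutex),
            ConnStatus:       peer.StatusConnected,
            ConnStatusUpdate: time.Now(),
        })
        if err != nil {
            t.Fatalf("update peer state: %v", err)
        }
        return rec
    }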
diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go
index a7cfb95c4c7..6afed2e1ea0 100644
--- a/client/internal/peer/status.go
+++ b/client/internal/peer/status.go
@@ -120,23 +120,24 @@ type FullStatus struct {

 // Status holds a state of peers, signal, management connections and relays
 type Status struct {
-	mux                   sync.Mutex
-	peers                 map[string]State
-	changeNotify          map[string]chan struct{}
-	signalState           bool
-	signalError           error
-	managementState      bool
-	managementError       error
-	relayStates           []relay.ProbeResult
-	localPeer             LocalPeerState
-	offlinePeers          []State
-	mgmAddress            string
-	signalAddress         string
-	notifier              *notifier
-	rosenpassEnabled      bool
-	rosenpassPermissive   bool
-	nsGroupStates         []NSGroupState
-	resolvedDomainsStates map[domain.Domain][]netip.Prefix
+	mux                    sync.Mutex
+	peers                  map[string]State
+	changeNotify           map[string]chan struct{}
+	signalState            bool
+	signalError            error
+	managementState       bool
+	managementError        error
+	relayStates            []relay.ProbeResult
+	localPeer              LocalPeerState
+	offlinePeers           []State
+	mgmAddress             string
+	signalAddress          string
+	notifier               *notifier
+	rosenpassEnabled       bool
+	rosenpassPermissive    bool
+	nsGroupStates          []NSGroupState
+	resolvedDomainsStates  map[domain.Domain][]netip.Prefix
+	aPeerConnStatusChanged chan struct{}

 	// To reduce the number of notification invocation this bool will be true when need to call the notification
 	// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
@@ -215,6 +216,7 @@ func (d *Status) RemovePeer(peerPubKey string) error {

 // UpdatePeerState updates peer status
 func (d *Status) UpdatePeerState(receivedState State) error {
+	var connStatusChanged bool
 	d.mux.Lock()
 	defer d.mux.Unlock()

@@ -227,8 +229,9 @@ func (d *Status) UpdatePeerState(receivedState State) error {
 		peerState.IP = receivedState.IP
 	}

-	if receivedState.GetRoutes() != nil {
-		peerState.SetRoutes(receivedState.GetRoutes())
+	routes := receivedState.GetRoutes()
+	if routes != nil {
+		peerState.SetRoutes(routes)
 	}

 	skipNotification := shouldSkipNotify(receivedState, peerState)
@@ -243,10 +246,16 @@ func (d *Status) UpdatePeerState(receivedState State) error {
 		peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint
 		peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint
 		peerState.RosenpassEnabled = receivedState.RosenpassEnabled
+		connStatusChanged = true
 	}

 	d.peers[receivedState.PubKey] = peerState

+	if connStatusChanged && d.aPeerConnStatusChanged != nil && (peerState.ConnStatus == StatusConnected || peerState.ConnStatus == StatusDisconnected) {
+		close(d.aPeerConnStatusChanged)
+		d.aPeerConnStatusChanged = nil
+	}
+
 	if skipNotification {
 		return nil
 	}
@@ -323,6 +332,17 @@ func (d *Status) FinishPeerListModifications() {
 	d.notifyPeerListChanged()
 }

+// GetPeersConnStatusChangeNotifier returns a notifier channel that is closed when any peer's connection status changes
+func (d *Status) GetPeersConnStatusChangeNotifier() <-chan struct{} {
+	d.mux.Lock()
+	defer d.mux.Unlock()
+	if d.aPeerConnStatusChanged == nil {
+		ch := make(chan struct{})
+		d.aPeerConnStatusChanged = ch
+	}
+	return d.aPeerConnStatusChanged
+}
+
 // GetPeerStateChangeNotifier returns a change notifier channel for a peer
 func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
 	d.mux.Lock()
@@ -342,6 +362,19 @@ func (d *Status) GetLocalPeerState() LocalPeerState {
 	return d.localPeer
 }

+// GetConnectedPeersCount returns the number of connected peers
+func (d *Status) GetConnectedPeersCount() int {
+	d.mux.Lock()
+	defer d.mux.Unlock()
+	var connectedCount int
+	for _, peer := range d.peers {
+		if peer.ConnStatus == StatusConnected {
+			connectedCount++
+		}
+	}
+	return connectedCount
+}
+
 // UpdateLocalPeerState updates local peer status
 func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
 	d.mux.Lock()
diff --git a/client/internal/peer/status_test.go b/client/internal/peer/status_test.go
index a4a6e608132..f01c3e381bd 100644
--- a/client/internal/peer/status_test.go
+++ b/client/internal/peer/status_test.go
@@ -77,6 +77,12 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
 	assert.Equal(t, fqdn, state.FQDN, "fqdn should be equal")
 }

+func TestGetPeersConnStatusChangeNotifierLogic(t *testing.T) {
+	status := NewRecorder("https://mgm")
+	ch := status.GetPeersConnStatusChangeNotifier()
+	assert.NotNil(t, ch, "channel shouldn't be nil")
+}
+
 func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
 	key := "abc"
 	ip := "10.10.10.10"
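The Status changes implement a one-shot broadcast: GetPeersConnStatusChangeNotifier hands out the current channel, UpdatePeerState closes it on a connect/disconnect transition and nils it out, and the next caller lazily allocates a fresh channel, so every waiter wakes exactly once per change. A self-contained sketch of that notifier pattern outside the Status struct (type and method names here are illustrative):

    package main

    import (
        "fmt"
        "sync"
        "time"
    )

    type notifier struct {
        mu sync.Mutex
        ch chan struct{}
    }

    // Subscribe returns the channel that will be closed on the next change.
    func (n *notifier) Subscribe() <-chan struct{} {
        n.mu.Lock()
        defer n.mu.Unlock()
        if n.ch == nil {
            n.ch = make(chan struct{})
        }
        return n.ch
    }

    // Notify closes the current channel (if any), waking all subscribers at once.
    func (n *notifier) Notify() {
        n.mu.Lock()
        defer n.mu.Unlock()
        if n.ch != nil {
            close(n.ch)
            n.ch = nil
        }
    }

    func main() {
        n := &notifier{}
        done := make(chan struct{})
        go func() {
            <-n.Subscribe() // blocks until the first Notify
            fmt.Println("peer connection status changed")
            close(done)
        }()
        time.Sleep(50 * time.Millisecond)
        n.Notify()
        <-done
    }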