Skip to content

Commit

Permalink
Handle disable-server-routes flag in userspace router
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal committed Jan 9, 2025
1 parent 28f5cd5 commit daf9359
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 54 deletions.
4 changes: 2 additions & 2 deletions client/firewall/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ import (
)

// NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}

// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface)
fm, err := uspfilter.Create(iface, disableServerRoutes)
if err != nil {
return nil, err
}
Expand Down
14 changes: 7 additions & 7 deletions client/firewall/create_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int

func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager)
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)

if !iface.IsUserspaceBind() {
return fm, err
Expand All @@ -47,10 +47,10 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
return createUserspaceFirewall(iface, fm)
return createUserspaceFirewall(iface, fm, disableServerRoutes)
}

func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
fm, err := createFW(iface)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
Expand All @@ -77,12 +77,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
}
}

func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
} else {
fm, errUsp = uspfilter.Create(iface)
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
}

if errUsp != nil {
Expand Down
58 changes: 37 additions & 21 deletions client/firewall/uspfilter/uspfilter.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,23 @@ type decoder struct {
}

// Create userspace firewall manager constructor
func Create(iface common.IFaceMapper) (*Manager, error) {
return create(iface)
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
return create(iface, disableServerRoutes)
}

func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
mgr, err := create(iface)
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
mgr, err := create(iface, disableServerRoutes)
if err != nil {
return nil, err
}

mgr.nativeFirewall = nativeFirewall

if disableServerRoutes {
// skip native vs userspace router logic altogether
return mgr, nil
}

if forceUserspaceRouter, _ := strconv.ParseBool(os.Getenv(EnvForceUserspaceRouter)); forceUserspaceRouter {
log.Info("userspace routing is forced")
return mgr, nil
Expand All @@ -125,7 +130,7 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.
return mgr, nil
}

func create(iface common.IFaceMapper) (*Manager, error) {
func create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))

m := &Manager{
Expand All @@ -147,6 +152,7 @@ func create(iface common.IFaceMapper) (*Manager, error) {
routeRules: make(map[string]RouteRule),
wgIface: iface,
localipmanager: newLocalIPManager(),
routingEnabled: false,
stateful: !disableConntrack,
// TODO: support changing log level from logrus
logger: nblog.NewFromLogrus(log.StandardLogger()),
Expand All @@ -166,23 +172,16 @@ func create(iface common.IFaceMapper) (*Manager, error) {
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
}

if disableRouting, _ := strconv.ParseBool(os.Getenv(EnvDisableUserspaceRouting)); disableRouting {
disableUspRouting, _ := strconv.ParseBool(os.Getenv(EnvDisableUserspaceRouting))
if disableUspRouting || disableServerRoutes {
log.Info("userspace routing is disabled")
return m, nil
} else {
m.routingEnabled = true
}

intf := iface.GetWGDevice()
if intf == nil {
log.Info("forwarding not supported")
// Only supported in userspace mode as we need to inject packets back into wireguard directly
} else {
var err error
m.forwarder, err = forwarder.New(iface, m.logger, m.netstack)
if err != nil {
log.Errorf("failed to create forwarder: %v", err)
} else {
m.routingEnabled = true
}
// netstack needs the forwarder for local traffic
if m.netstack || m.routingEnabled {
m.initForwarder(iface)
}

if err := iface.SetFilter(m); err != nil {
Expand All @@ -191,6 +190,25 @@ func create(iface common.IFaceMapper) (*Manager, error) {
return m, nil
}

func (m *Manager) initForwarder(iface common.IFaceMapper) {
// Only supported in userspace mode as we need to inject packets back into wireguard directly
intf := iface.GetWGDevice()
if intf == nil {
log.Info("forwarding not supported")
m.routingEnabled = false
return
}

forwarder, err := forwarder.New(iface, m.logger, m.netstack)
if err != nil {
log.Errorf("failed to create forwarder: %v", err)
m.routingEnabled = false
return
}

m.forwarder = forwarder
}

func (m *Manager) Init(*statemanager.Manager) error {
return nil
}
Expand Down Expand Up @@ -509,8 +527,6 @@ func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
// dropFilter implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped.
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
// TODO: Disable router if --disable-server-router is set

m.mutex.RLock()
defer m.mutex.RUnlock()

Expand Down
16 changes: 8 additions & 8 deletions client/firewall/uspfilter/uspfilter_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
Expand Down Expand Up @@ -203,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
Expand Down Expand Up @@ -251,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
Expand Down Expand Up @@ -450,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
Expand Down Expand Up @@ -577,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {

manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
Expand Down Expand Up @@ -668,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {

manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
Expand Down Expand Up @@ -787,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {

manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
Expand Down Expand Up @@ -875,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {

manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
Expand Down
4 changes: 2 additions & 2 deletions client/firewall/uspfilter/uspfilter_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestPeerACLFiltering(t *testing.T) {
},
}

manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, false)
require.NoError(t, err)
require.NotNil(t, manager)

Expand Down Expand Up @@ -249,7 +249,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
},
}

manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, false)
require.NoError(tb, err)
require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled)
Expand Down
20 changes: 10 additions & 10 deletions client/firewall/uspfilter/uspfilter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}

m, err := Create(ifaceMock)
m, err := Create(ifaceMock, false)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
Expand All @@ -82,7 +82,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
},
}

m, err := Create(ifaceMock)
m, err := Create(ifaceMock, false)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
Expand Down Expand Up @@ -117,7 +117,7 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}

m, err := Create(ifaceMock)
m, err := Create(ifaceMock, false)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
Expand Down Expand Up @@ -210,7 +210,7 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
require.NoError(t, err)

manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
Expand Down Expand Up @@ -263,7 +263,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}

m, err := Create(ifaceMock)
m, err := Create(ifaceMock, false)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
Expand Down Expand Up @@ -307,7 +307,7 @@ func TestNotMatchByIP(t *testing.T) {
},
}

m, err := Create(ifaceMock)
m, err := Create(ifaceMock, false)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
Expand Down Expand Up @@ -376,7 +376,7 @@ func TestRemovePacketHook(t *testing.T) {
}

// creating manager instance
manager, err := Create(iface)
manager, err := Create(iface, false)
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
}
Expand Down Expand Up @@ -422,7 +422,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
require.NoError(t, err)

manager.wgNetwork = &net.IPNet{
Expand Down Expand Up @@ -508,7 +508,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, false)
require.NoError(t, err)
time.Sleep(time.Second)

Expand Down Expand Up @@ -539,7 +539,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
require.NoError(t, err)

manager.wgNetwork = &net.IPNet{
Expand Down
4 changes: 2 additions & 2 deletions client/internal/acl/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestDefaultManager(t *testing.T) {
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()

// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil)
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
if err != nil {
t.Errorf("create firewall: %v", err)
return
Expand Down Expand Up @@ -346,7 +346,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()

// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil)
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
if err != nil {
t.Errorf("create firewall: %v", err)
return
Expand Down
2 changes: 1 addition & 1 deletion client/internal/dns/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
return nil, err
}

pf, err := uspfilter.Create(wgIface)
pf, err := uspfilter.Create(wgIface, false)
if err != nil {
t.Fatalf("failed to create uspfilter: %v", err)
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ func (e *Engine) createFirewall() error {
}

var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes)
if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err)
return nil
Expand Down

0 comments on commit daf9359

Please sign in to comment.