Skip to content

Commit 560438f

Browse files
committed
Add fwmark support
1 parent db2bd81 commit 560438f

File tree

3 files changed

+33
-11
lines changed

3 files changed

+33
-11
lines changed

main.go

+5
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ var (
4949
denyRange IPPortRangeSlice
5050
dnsTTL time.Duration
5151
readyFile string
52+
fwmark int
5253
)
5354

5455
func initFlagSet(flag *flag.FlagSet) {
@@ -73,6 +74,7 @@ func initFlagSet(flag *flag.FlagSet) {
7374
flag.Var(&denyRange, "deny", "When routing, deny specified IP prefix and port range")
7475
flag.DurationVar(&dnsTTL, "dns-ttl", time.Duration(5*time.Second), "For how long to cache DNS in case of dns labels passed to forward target.")
7576
flag.StringVar(&readyFile, "ready-file", "", "After initialization, write a byte to this file to signal readiness")
77+
flag.IntVar(&fwmark, "fwmark", 0, "Set fwmark on outbound packets")
7678
}
7779

7880
func main() {
@@ -99,6 +101,8 @@ type State struct {
99101
denyRange IPPortRangeSlice
100102

101103
srcIPs SrcIPs
104+
105+
fwmark int
102106
}
103107

104108
func Main(programName string, args []string) int {
@@ -170,6 +174,7 @@ func Main(programName string, args []string) int {
170174
state.srcIPs.srcIPv6 = sourceIPv6.ip
171175
state.allowRange = allowRange
172176
state.denyRange = denyRange
177+
state.fwmark = fwmark
173178

174179
logConnections = !quiet
175180

net.go

+20-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"strconv"
88
"strings"
99
"sync"
10+
"syscall"
1011
"time"
1112

1213
"gvisor.dev/gvisor/pkg/tcpip"
@@ -243,8 +244,22 @@ func netParseOrResolveIP(h string) (_ip net.IP, _resolved bool, _err error) {
243244
return ip, true, err
244245
}
245246

246-
func OutboundDial(srcIPs *SrcIPs, dst net.Addr) (net.Conn, error) {
247+
func OutboundDial(state *State, dst net.Addr) (net.Conn, error) {
248+
srcIPs := &state.srcIPs
247249
network := dst.Network()
250+
dialer := &net.Dialer{}
251+
if state.fwmark != 0 {
252+
dialer.Control = func(network, address string, c syscall.RawConn) error {
253+
var controlErr error
254+
err := c.Control(func(fd uintptr) {
255+
controlErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, state.fwmark)
256+
})
257+
if err != nil {
258+
return err
259+
}
260+
return controlErr
261+
}
262+
}
248263
if network == "tcp" {
249264
dstTcp := dst.(*net.TCPAddr)
250265
var srcTcp *net.TCPAddr
@@ -254,7 +269,8 @@ func OutboundDial(srcIPs *SrcIPs, dst net.Addr) (net.Conn, error) {
254269
if srcIPs != nil && dstTcp.IP.To4() == nil && srcIPs.srcIPv6 != nil {
255270
srcTcp = &net.TCPAddr{IP: srcIPs.srcIPv6}
256271
}
257-
return net.DialTCP(network, srcTcp, dstTcp)
272+
dialer.LocalAddr = srcTcp
273+
return dialer.Dial(network, dstTcp.String())
258274
}
259275
if network == "udp" {
260276
dstUdp := dst.(*net.UDPAddr)
@@ -265,7 +281,8 @@ func OutboundDial(srcIPs *SrcIPs, dst net.Addr) (net.Conn, error) {
265281
if srcIPs != nil && dstUdp.IP.To4() == nil && srcIPs.srcIPv6 != nil {
266282
srcUdp = &net.UDPAddr{IP: srcIPs.srcIPv6}
267283
}
268-
return net.DialUDP(network, srcUdp, dstUdp)
284+
dialer.LocalAddr = srcUdp
285+
return dialer.Dial(network, dstUdp.String())
269286
}
270287
return nil, fmt.Errorf("not tcp/udp")
271288
}

routing.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ func UdpRoutingHandler(s *stack.Stack, state *State) func(*udp.ForwarderRequest)
7070

7171
go func() {
7272
if rf != nil {
73-
RemoteForward(conn, &state.srcIPs, rf)
73+
RemoteForward(conn, state, rf)
7474
} else {
75-
RoutingForward(conn, &state.srcIPs, loc)
75+
RoutingForward(conn, state, loc)
7676
}
7777
}()
7878
}
@@ -118,22 +118,22 @@ func TcpRoutingHandler(state *State) func(*tcp.ForwarderRequest) {
118118

119119
go func() {
120120
if rf != nil {
121-
RemoteForward(conn, &state.srcIPs, rf)
121+
RemoteForward(conn, state, rf)
122122
} else {
123-
RoutingForward(conn, &state.srcIPs, loc)
123+
RoutingForward(conn, state, loc)
124124
}
125125
}()
126126
}
127127
return h
128128
}
129129

130-
func RoutingForward(guest KaConn, srcIPs *SrcIPs, loc net.Addr) {
130+
func RoutingForward(guest KaConn, state *State, loc net.Addr) {
131131
// Cache guest.RemoteAddr() because it becomes nil on
132132
// guest.Close().
133133
guestRemoteAddr := guest.RemoteAddr()
134134

135135
var pe ProxyError
136-
xhost, err := OutboundDial(srcIPs, loc)
136+
xhost, err := OutboundDial(state, loc)
137137
if err != nil {
138138
SetResetOnClose(guest)
139139
guest.Close()
@@ -173,7 +173,7 @@ func RoutingForward(guest KaConn, srcIPs *SrcIPs, loc net.Addr) {
173173
}
174174
}
175175

176-
func RemoteForward(guest KaConn, srcIPs *SrcIPs, rf *FwdAddr) {
176+
func RemoteForward(guest KaConn, state *State, rf *FwdAddr) {
177177
// Cache guest.RemoteAddr() because it becomes nil on
178178
// guest.Close().
179179
guestRemoteAddr := guest.RemoteAddr()
@@ -190,7 +190,7 @@ func RemoteForward(guest KaConn, srcIPs *SrcIPs, rf *FwdAddr) {
190190
err)
191191
return
192192
}
193-
xhost, err := OutboundDial(srcIPs, hostAddr)
193+
xhost, err := OutboundDial(state, hostAddr)
194194
if err != nil {
195195
SetResetOnClose(guest)
196196
guest.Close()

0 commit comments

Comments
 (0)