diff --git a/go.mod b/go.mod index 9e498f52..50a38a37 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module github.com/mholt/caddy-l4 -go 1.22.0 - -toolchain go1.23.0 +go 1.23.0 require ( github.com/caddyserver/caddy/v2 v2.8.4 @@ -14,6 +12,7 @@ require ( go.uber.org/zap v1.27.0 golang.org/x/crypto v0.28.0 golang.org/x/net v0.30.0 + golang.org/x/sys v0.33.0 golang.org/x/time v0.7.0 ) @@ -140,7 +139,6 @@ require ( golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect golang.org/x/mod v0.18.0 // indirect golang.org/x/sync v0.8.0 // indirect - golang.org/x/sys v0.27.0 // indirect golang.org/x/term v0.25.0 // indirect golang.org/x/text v0.19.0 // indirect golang.org/x/tools v0.22.0 // indirect diff --git a/go.sum b/go.sum index 2e20969a..ea7c1691 100644 --- a/go.sum +++ b/go.sum @@ -582,8 +582,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= -golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= diff --git a/layer4/app.go b/layer4/app.go index 0e69a48f..4b83c726 100644 --- a/layer4/app.go +++ b/layer4/app.go @@ -66,7 +66,7 @@ func (a *App) Provision(ctx caddy.Context) error { func (a *App) Start() error { for _, s := range a.Servers { for _, addr := range s.listenAddrs { - listeners, err := addr.ListenAll(a.ctx, net.ListenConfig{}) + listeners, err := addr.ListenAll(a.ctx, listenConfig) if err != nil { return err } diff --git a/layer4/packet_conn_linux.go b/layer4/packet_conn_linux.go new file mode 100644 index 00000000..e469c254 --- /dev/null +++ b/layer4/packet_conn_linux.go @@ -0,0 +1,73 @@ +//go:build linux + +package layer4 + +import ( + "bytes" + "encoding/binary" + "net" + "strings" + "syscall" + "unsafe" +) + +const ( + hdrSize = unsafe.Sizeof(syscall.Cmsghdr{}) + oobSize = 128 // enough to hold the local address +) + +var listenConfig = net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + if !strings.HasPrefix(network, "udp") { + return nil + } + + var syscallErr error + err := c.Control(func(fd uintptr) { + // TODO: check if the address is ipv6 only + syscallErr = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_PKTINFO, 1) + if strings.HasSuffix(network, "6") && syscallErr == nil { + syscallErr = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_RECVPKTINFO, 1) + } + }) + if err == nil { + err = syscallErr + } + return err + }, +} + +func readFrom(pc net.PacketConn, buf []byte) (int, net.Addr, net.Addr, error) { + if udpConn, ok := pc.(*net.UDPConn); ok { + oob := make([]byte, oobSize) + n, oobN, _, rAddr, err := udpConn.ReadMsgUDP(buf, oob) + if err != nil { + return 0, nil, nil, err + } + if oobN < int(hdrSize) { + return n, rAddr, nil, nil + } + + la := udpConn.LocalAddr().(*net.UDPAddr) + lAddr := &net.UDPAddr{ + IP: la.IP, + Port: la.Port, + Zone: la.Zone, + } + br := bytes.NewReader(oob[:oobN]) + var hdr syscall.Cmsghdr + _ = binary.Read(br, binary.LittleEndian, &hdr) + if hdr.Level == syscall.IPPROTO_IP && hdr.Type == syscall.IP_PKTINFO { + var addr syscall.Inet4Pktinfo + _ = binary.Read(br, binary.LittleEndian, &addr) + lAddr.IP = addr.Addr[:] + } else if hdr.Level == syscall.IPPROTO_IPV6 && hdr.Type == syscall.IPV6_PKTINFO { + var addr syscall.Inet6Pktinfo + _ = binary.Read(br, binary.LittleEndian, &addr) + lAddr.IP = addr.Addr[:] + } + return n, rAddr, lAddr, nil + } + n, addr, err := pc.ReadFrom(buf) + return n, addr, nil, err +} diff --git a/layer4/packet_conn_other.go b/layer4/packet_conn_other.go new file mode 100644 index 00000000..ad98a678 --- /dev/null +++ b/layer4/packet_conn_other.go @@ -0,0 +1,12 @@ +//go:build !linux && !windows + +package layer4 + +import "net" + +var listenConfig = net.ListenConfig{} + +func readFrom(pc net.PacketConn, buf []byte) (int, net.Addr, net.Addr, error) { + n, addr, err := pc.ReadFrom(buf) + return n, addr, nil, err +} diff --git a/layer4/packet_conn_windows.go b/layer4/packet_conn_windows.go new file mode 100644 index 00000000..e12577dc --- /dev/null +++ b/layer4/packet_conn_windows.go @@ -0,0 +1,69 @@ +//go:build windows + +package layer4 + +import ( + "net" + "strings" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + hdrSize = unsafe.Sizeof(windows.WSACMSGHDR{}) + oobSize = 64 // enough to hold the local address +) + +var listenConfig = net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + if !strings.HasPrefix(network, "udp") { + return nil + } + + var syscallErr error + err := c.Control(func(fd uintptr) { + // TODO: check if the address is ipv6 only + syscallErr = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, windows.IP_PKTINFO, 1) + if strings.HasSuffix(network, "6") && syscallErr == nil { + syscallErr = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, windows.IPV6_PKTINFO, 1) + } + }) + if err == nil { + err = syscallErr + } + return err + }, +} + +func readFrom(pc net.PacketConn, buf []byte) (int, net.Addr, net.Addr, error) { + if udpConn, ok := pc.(*net.UDPConn); ok { + oob := make([]byte, oobSize) + n, oobN, _, rAddr, err := udpConn.ReadMsgUDP(buf, oob) + if err != nil { + return 0, nil, nil, err + } + if oobN < int(hdrSize) { + return n, rAddr, nil, nil + } + + la := udpConn.LocalAddr().(*net.UDPAddr) + lAddr := &net.UDPAddr{ + IP: la.IP, + Port: la.Port, + Zone: la.Zone, + } + hdr := (*windows.WSACMSGHDR)(unsafe.Pointer(&oob[0])) + if hdr.Level == windows.IPPROTO_IP && hdr.Type == windows.IP_PKTINFO { + addr := *(*windows.IN_PKTINFO)(unsafe.Pointer(&oob[hdrSize])) + lAddr.IP = addr.Addr[:] + } else if hdr.Level == windows.IPPROTO_IPV6 && hdr.Type == windows.IPV6_PKTINFO { + addr := *(*windows.IN6_PKTINFO)(unsafe.Pointer(&oob[hdrSize])) + lAddr.IP = addr.Addr[:] + } + return n, rAddr, lAddr, nil + } + n, addr, err := pc.ReadFrom(buf) + return n, addr, nil, err +} diff --git a/layer4/server.go b/layer4/server.go index 74d6014f..2a2a8442 100644 --- a/layer4/server.go +++ b/layer4/server.go @@ -99,7 +99,7 @@ func (s *Server) servePacket(pc net.PacketConn) error { go func(packets chan packet) { for { buf := udpBufPool.Get().([]byte) - n, addr, err := pc.ReadFrom(buf) + n, rAddr, lAddr, err := readFrom(pc, buf) if err != nil { var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { @@ -111,7 +111,8 @@ func (s *Server) servePacket(pc net.PacketConn) error { packets <- packet{ pooledBuf: buf, n: n, - addr: addr, + rAddr: rAddr, + lAddr: lAddr, } } }(packets) @@ -134,17 +135,19 @@ func (s *Server) servePacket(pc net.PacketConn) error { if pkt.err != nil { return pkt.err } - conn, ok := udpConns[pkt.addr.String()] + addrKey := formatAddrs(pkt.rAddr, pkt.lAddr) + conn, ok := udpConns[addrKey] if !ok { // No existing proxy handler is running for this downstream. // Create one now. conn = &packetConn{ PacketConn: pc, readCh: make(chan *packet, 5), - addr: pkt.addr, + rAddr: pkt.rAddr, + lAddr: pkt.lAddr, closeCh: closeCh, } - udpConns[pkt.addr.String()] = conn + udpConns[addrKey] = conn go func(conn *packetConn) { s.handle(conn) // It might seem cleaner to send to closeCh here rather than @@ -219,6 +222,14 @@ func (s *Server) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { return nil } +// lAddr may be nil for non-udp sockets and unsupported platforms +func formatAddrs(rAddr, lAddr net.Addr) string { + if lAddr == nil { + return rAddr.String() + ":nil" + } + return rAddr.String() + ":" + lAddr.String() +} + type packet struct { // The underlying bytes slice that was gotten from udpBufPool. It's up to // packetConn to return it to udpBufPool once it's consumed. @@ -227,13 +238,18 @@ type packet struct { n int // Error that occurred while reading from socket err error - // Address of downstream - addr net.Addr + // remote address + rAddr net.Addr + // local address, may be nil + lAddr net.Addr } type packetConn struct { net.PacketConn - addr net.Addr + // remote address + rAddr net.Addr + // local address, may be nil + lAddr net.Addr readCh chan *packet closeCh chan string // If not nil, then the previous Read() call didn't consume all the data @@ -325,14 +341,14 @@ func (pc *packetConn) Read(b []byte) (n int, err error) { // Although Close() also does this, we inform the server loop early about // the closure to ensure that if any new packets are received from this // connection in the meantime, a new handler will be started. - pc.closeCh <- pc.addr.String() + pc.closeCh <- formatAddrs(pc.rAddr, pc.lAddr) // Returning EOF here ensures that io.Copy() waiting on the downstream for // reads will terminate. return 0, io.EOF } func (pc *packetConn) Write(b []byte) (n int, err error) { - return pc.PacketConn.WriteTo(b, pc.addr) + return pc.PacketConn.WriteTo(b, pc.rAddr) } func (pc *packetConn) Close() error { @@ -348,13 +364,20 @@ func (pc *packetConn) Close() error { } // We may have already done this earlier in Read(), but just in case // Read() wasn't being called, (re-)notify server loop we're closed. - pc.closeCh <- pc.addr.String() + pc.closeCh <- formatAddrs(pc.rAddr, pc.lAddr) // We don't call net.PacketConn.Close() here as we would stop the UDP // server. return nil } -func (pc *packetConn) RemoteAddr() net.Addr { return pc.addr } +func (pc *packetConn) RemoteAddr() net.Addr { return pc.rAddr } + +func (pc *packetConn) LocalAddr() net.Addr { + if pc.lAddr != nil { + return pc.lAddr + } + return pc.PacketConn.LocalAddr() +} var udpBufPool = sync.Pool{ New: func() interface{} {