Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
module github.com/mholt/caddy-l4

go 1.22.0

toolchain go1.23.0
go 1.23.0

require (
github.com/caddyserver/caddy/v2 v2.8.4
Expand All @@ -14,6 +12,7 @@ require (
go.uber.org/zap v1.27.0
golang.org/x/crypto v0.28.0
golang.org/x/net v0.30.0
golang.org/x/sys v0.33.0
golang.org/x/time v0.7.0
)

Expand Down Expand Up @@ -140,7 +139,6 @@ require (
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
golang.org/x/mod v0.18.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.27.0 // indirect
golang.org/x/term v0.25.0 // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/tools v0.22.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s=
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
Expand Down
2 changes: 1 addition & 1 deletion layer4/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (a *App) Provision(ctx caddy.Context) error {
func (a *App) Start() error {
for _, s := range a.Servers {
for _, addr := range s.listenAddrs {
listeners, err := addr.ListenAll(a.ctx, net.ListenConfig{})
listeners, err := addr.ListenAll(a.ctx, listenConfig)
if err != nil {
return err
}
Expand Down
73 changes: 73 additions & 0 deletions layer4/packet_conn_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//go:build linux

package layer4

import (
"bytes"
"encoding/binary"
"net"
"strings"
"syscall"
"unsafe"
)

const (
hdrSize = unsafe.Sizeof(syscall.Cmsghdr{})
oobSize = 128 // enough to hold the local address
)

var listenConfig = net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
if !strings.HasPrefix(network, "udp") {
return nil
}

var syscallErr error
err := c.Control(func(fd uintptr) {
// TODO: check if the address is ipv6 only
syscallErr = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_PKTINFO, 1)
if strings.HasSuffix(network, "6") && syscallErr == nil {
syscallErr = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_RECVPKTINFO, 1)
}
})
if err == nil {
err = syscallErr
}
return err
},
}

func readFrom(pc net.PacketConn, buf []byte) (int, net.Addr, net.Addr, error) {
if udpConn, ok := pc.(*net.UDPConn); ok {
oob := make([]byte, oobSize)
n, oobN, _, rAddr, err := udpConn.ReadMsgUDP(buf, oob)
if err != nil {
return 0, nil, nil, err
}
if oobN < int(hdrSize) {
return n, rAddr, nil, nil
}

la := udpConn.LocalAddr().(*net.UDPAddr)
lAddr := &net.UDPAddr{
IP: la.IP,
Port: la.Port,
Zone: la.Zone,
}
br := bytes.NewReader(oob[:oobN])
var hdr syscall.Cmsghdr
_ = binary.Read(br, binary.LittleEndian, &hdr)
if hdr.Level == syscall.IPPROTO_IP && hdr.Type == syscall.IP_PKTINFO {
var addr syscall.Inet4Pktinfo
_ = binary.Read(br, binary.LittleEndian, &addr)
lAddr.IP = addr.Addr[:]
} else if hdr.Level == syscall.IPPROTO_IPV6 && hdr.Type == syscall.IPV6_PKTINFO {
var addr syscall.Inet6Pktinfo
_ = binary.Read(br, binary.LittleEndian, &addr)
lAddr.IP = addr.Addr[:]
}
return n, rAddr, lAddr, nil
}
n, addr, err := pc.ReadFrom(buf)
return n, addr, nil, err
}
12 changes: 12 additions & 0 deletions layer4/packet_conn_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//go:build !linux && !windows

package layer4

import "net"

var listenConfig = net.ListenConfig{}

func readFrom(pc net.PacketConn, buf []byte) (int, net.Addr, net.Addr, error) {
n, addr, err := pc.ReadFrom(buf)
return n, addr, nil, err
}
69 changes: 69 additions & 0 deletions layer4/packet_conn_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
//go:build windows

package layer4

import (
"net"
"strings"
"syscall"
"unsafe"

"golang.org/x/sys/windows"
)

const (
hdrSize = unsafe.Sizeof(windows.WSACMSGHDR{})
oobSize = 64 // enough to hold the local address
)

var listenConfig = net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
if !strings.HasPrefix(network, "udp") {
return nil
}

var syscallErr error
err := c.Control(func(fd uintptr) {
// TODO: check if the address is ipv6 only
syscallErr = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, windows.IP_PKTINFO, 1)
if strings.HasSuffix(network, "6") && syscallErr == nil {
syscallErr = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, windows.IPV6_PKTINFO, 1)
}
})
if err == nil {
err = syscallErr
}
return err
},
}

func readFrom(pc net.PacketConn, buf []byte) (int, net.Addr, net.Addr, error) {
if udpConn, ok := pc.(*net.UDPConn); ok {
oob := make([]byte, oobSize)
n, oobN, _, rAddr, err := udpConn.ReadMsgUDP(buf, oob)
if err != nil {
return 0, nil, nil, err
}
if oobN < int(hdrSize) {
return n, rAddr, nil, nil
}

la := udpConn.LocalAddr().(*net.UDPAddr)
lAddr := &net.UDPAddr{
IP: la.IP,
Port: la.Port,
Zone: la.Zone,
}
hdr := (*windows.WSACMSGHDR)(unsafe.Pointer(&oob[0]))
if hdr.Level == windows.IPPROTO_IP && hdr.Type == windows.IP_PKTINFO {
addr := *(*windows.IN_PKTINFO)(unsafe.Pointer(&oob[hdrSize]))
lAddr.IP = addr.Addr[:]
} else if hdr.Level == windows.IPPROTO_IPV6 && hdr.Type == windows.IPV6_PKTINFO {
addr := *(*windows.IN6_PKTINFO)(unsafe.Pointer(&oob[hdrSize]))
lAddr.IP = addr.Addr[:]
}
return n, rAddr, lAddr, nil
}
n, addr, err := pc.ReadFrom(buf)
return n, addr, nil, err
}
47 changes: 35 additions & 12 deletions layer4/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (s *Server) servePacket(pc net.PacketConn) error {
go func(packets chan packet) {
for {
buf := udpBufPool.Get().([]byte)
n, addr, err := pc.ReadFrom(buf)
n, rAddr, lAddr, err := readFrom(pc, buf)
if err != nil {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
Expand All @@ -111,7 +111,8 @@ func (s *Server) servePacket(pc net.PacketConn) error {
packets <- packet{
pooledBuf: buf,
n: n,
addr: addr,
rAddr: rAddr,
lAddr: lAddr,
}
}
}(packets)
Expand All @@ -134,17 +135,19 @@ func (s *Server) servePacket(pc net.PacketConn) error {
if pkt.err != nil {
return pkt.err
}
conn, ok := udpConns[pkt.addr.String()]
addrKey := formatAddrs(pkt.rAddr, pkt.lAddr)
conn, ok := udpConns[addrKey]
if !ok {
// No existing proxy handler is running for this downstream.
// Create one now.
conn = &packetConn{
PacketConn: pc,
readCh: make(chan *packet, 5),
addr: pkt.addr,
rAddr: pkt.rAddr,
lAddr: pkt.lAddr,
closeCh: closeCh,
}
udpConns[pkt.addr.String()] = conn
udpConns[addrKey] = conn
go func(conn *packetConn) {
s.handle(conn)
// It might seem cleaner to send to closeCh here rather than
Expand Down Expand Up @@ -219,6 +222,14 @@ func (s *Server) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return nil
}

// lAddr may be nil for non-udp sockets and unsupported platforms
func formatAddrs(rAddr, lAddr net.Addr) string {
if lAddr == nil {
return rAddr.String() + ":nil"
}
return rAddr.String() + ":" + lAddr.String()
}

type packet struct {
// The underlying bytes slice that was gotten from udpBufPool. It's up to
// packetConn to return it to udpBufPool once it's consumed.
Expand All @@ -227,13 +238,18 @@ type packet struct {
n int
// Error that occurred while reading from socket
err error
// Address of downstream
addr net.Addr
// remote address
rAddr net.Addr
// local address, may be nil
lAddr net.Addr
}

type packetConn struct {
net.PacketConn
addr net.Addr
// remote address
rAddr net.Addr
// local address, may be nil
lAddr net.Addr
readCh chan *packet
closeCh chan string
// If not nil, then the previous Read() call didn't consume all the data
Expand Down Expand Up @@ -325,14 +341,14 @@ func (pc *packetConn) Read(b []byte) (n int, err error) {
// Although Close() also does this, we inform the server loop early about
// the closure to ensure that if any new packets are received from this
// connection in the meantime, a new handler will be started.
pc.closeCh <- pc.addr.String()
pc.closeCh <- formatAddrs(pc.rAddr, pc.lAddr)
// Returning EOF here ensures that io.Copy() waiting on the downstream for
// reads will terminate.
return 0, io.EOF
}

func (pc *packetConn) Write(b []byte) (n int, err error) {
return pc.PacketConn.WriteTo(b, pc.addr)
return pc.PacketConn.WriteTo(b, pc.rAddr)
}

func (pc *packetConn) Close() error {
Expand All @@ -348,13 +364,20 @@ func (pc *packetConn) Close() error {
}
// We may have already done this earlier in Read(), but just in case
// Read() wasn't being called, (re-)notify server loop we're closed.
pc.closeCh <- pc.addr.String()
pc.closeCh <- formatAddrs(pc.rAddr, pc.lAddr)
// We don't call net.PacketConn.Close() here as we would stop the UDP
// server.
return nil
}

func (pc *packetConn) RemoteAddr() net.Addr { return pc.addr }
func (pc *packetConn) RemoteAddr() net.Addr { return pc.rAddr }

func (pc *packetConn) LocalAddr() net.Addr {
if pc.lAddr != nil {
return pc.lAddr
}
return pc.PacketConn.LocalAddr()
}

var udpBufPool = sync.Pool{
New: func() interface{} {
Expand Down