Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ Usage of ./go-mmproxy:
Path to a file that contains allowed subnets of the proxy servers
-close-after int
Number of seconds after which UDP socket will be cleaned up (default 60)
-dynamic-destination
Traffic will be forwarded to the destination specified in the PROXY protocol header
-l string
Address the proxy listens on (default "0.0.0.0:8443")
-listeners int
Expand Down
3 changes: 2 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func init() {
flag.StringVar(&listenAddrStr, "l", "0.0.0.0:8443", "Address the proxy listens on")
flag.StringVar(&targetAddr4Str, "4", "127.0.0.1:443", "Address to which IPv4 traffic will be forwarded to")
flag.StringVar(&targetAddr6Str, "6", "[::1]:443", "Address to which IPv6 traffic will be forwarded to")
flag.BoolVar(&opts.DynamicDestination, "dynamic-destination", false, "Traffic will be forwarded to the destination specified in the PROXY protocol header")
flag.IntVar(&opts.Mark, "mark", 0, "The mark that will be set on outbound packets")
flag.IntVar(&opts.Verbose, "v", 0, `0 - no logging of individual connections
1 - log errors occurring in individual connections
Expand All @@ -44,7 +45,7 @@ func init() {
"Path to a file that contains allowed subnets of the proxy servers")
flag.IntVar(&listeners, "listeners", 1,
"Number of listener sockets that will be opened for the listen address (Linux 3.9+)")
flag.IntVar(&udpCloseAfterInt, "close-after", 60, "Number of seconds after which UDP socket will be cleaned up")
flag.IntVar(&udpCloseAfterInt, "close-after", 60, "Number of seconds after which UDP socket will be cleaned up on inactivity")
}

func listen(ctx context.Context, listenerNum int, parentLogger *slog.Logger, listenErrors chan<- error) {
Expand Down
6 changes: 4 additions & 2 deletions tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,17 @@ func handleConnection(conn net.Conn, opts *utils.Options, logger *slog.Logger) {
return
}

saddr, _, restBytes, err := proxyprotocol.ReadRemoteAddr(buffer[:n], utils.TCP)
saddr, daddr, restBytes, err := proxyprotocol.ReadRemoteAddr(buffer[:n], utils.TCP)
if err != nil {
logger.Debug("failed to parse PROXY header", "error", err, slog.Bool("dropConnection", true))
return
}

targetAddr := opts.TargetAddr6
if saddr.IsValid() {
if saddr.Addr().Is4() {
if opts.DynamicDestination && daddr.IsValid() {
targetAddr = daddr
} else if saddr.Addr().Is4() {
targetAddr = opts.TargetAddr4
}
} else {
Expand Down
55 changes: 52 additions & 3 deletions tests/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestListen(t *testing.T) {
receivedData4 := make(chan listenResult, 1)
go runServer(t, "127.0.0.1:54321", receivedData4)

time.Sleep(1 * time.Second)
time.Sleep(100 * time.Millisecond)

conn, err := net.Dial("tcp", "127.0.0.1:12345")
if err != nil {
Expand Down Expand Up @@ -123,7 +123,7 @@ func TestListen_unknown(t *testing.T) {
receivedData4 := make(chan listenResult, 1)
go runServer(t, "127.0.0.1:54322", receivedData4)

time.Sleep(1 * time.Second)
time.Sleep(100 * time.Millisecond)

conn, err := net.Dial("tcp", "127.0.0.1:12346")
if err != nil {
Expand Down Expand Up @@ -171,7 +171,7 @@ func TestListen_proxyV2(t *testing.T) {
receivedData4 := make(chan listenResult, 1)
go runServer(t, "127.0.0.1:54323", receivedData4)

time.Sleep(1 * time.Second)
time.Sleep(100 * time.Millisecond)

conn, err := net.Dial("tcp", "127.0.0.1:12347")
if err != nil {
Expand Down Expand Up @@ -200,3 +200,52 @@ func TestListen_proxyV2(t *testing.T) {
t.Errorf("Unexpected source address: %v", result.saddr)
}
}

func TestTCPListen_DynamicDestination(t *testing.T) {
opts := utils.Options{
Protocol: utils.TCP,
ListenAddr: netip.MustParseAddrPort("0.0.0.0:12350"),
TargetAddr4: netip.MustParseAddrPort("127.0.0.1:443"),
TargetAddr6: netip.MustParseAddrPort("[::1]:443"),
DynamicDestination: true,
Mark: 0,
AllowedSubnets: nil,
Verbose: 2,
}

lvl := slog.LevelInfo
if opts.Verbose > 0 {
lvl = slog.LevelDebug
}

logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: lvl}))

listenConfig := net.ListenConfig{}
errors := make(chan error, 1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

go tcp.Listen(ctx, &listenConfig, &opts, logger, errors)

receivedData4 := make(chan listenResult, 1)
go runServer(t, "127.0.0.1:56324", receivedData4)

time.Sleep(100 * time.Millisecond)

conn, err := net.Dial("tcp", "127.0.0.1:12350")
if err != nil {
t.Fatalf("Failed to connect to server: %v", err)
}
defer conn.Close()

conn.Write([]byte("PROXY TCP4 192.168.0.1 127.0.0.1 56324 56324\r\nmoredata"))
result := <-receivedData4

if !reflect.DeepEqual(result.data, []byte("moredata")) {
t.Errorf("Unexpected data: %v", result.data)
}

if result.saddr.String() != "192.168.0.1:56324" {
t.Errorf("Unexpected source address: %v", result.saddr)
}
}
61 changes: 60 additions & 1 deletion tests/udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestListenUDP(t *testing.T) {
receivedData4 := make(chan listenResult, 1)
go runUDPServer(t, "127.0.0.1:54323", receivedData4)

time.Sleep(1 * time.Second)
time.Sleep(100 * time.Millisecond)

conn, err := net.Dial("udp", "127.0.0.1:12347")
if err != nil {
Expand Down Expand Up @@ -94,3 +94,62 @@ func TestListenUDP(t *testing.T) {
t.Errorf("Unexpected source address: %v", result.saddr)
}
}

func TestListenUDP_DynamicDestination(t *testing.T) {
opts := utils.Options{
Protocol: utils.UDP,
ListenAddr: netip.MustParseAddrPort("0.0.0.0:12348"),
TargetAddr4: netip.MustParseAddrPort("127.0.0.1:443"),
TargetAddr6: netip.MustParseAddrPort("[::1]:443"),
DynamicDestination: true,
Mark: 0,
AllowedSubnets: nil,
Verbose: 2,
}

lvl := slog.LevelInfo
if opts.Verbose > 0 {
lvl = slog.LevelDebug
}

logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: lvl}))

listenConfig := net.ListenConfig{}
errors := make(chan error, 1)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go udp.Listen(ctx, &listenConfig, &opts, logger, errors)

receivedData4 := make(chan listenResult, 1)
go runUDPServer(t, "127.0.0.1:56324", receivedData4)

time.Sleep(100 * time.Millisecond)

conn, err := net.Dial("udp", "127.0.0.1:12348")
if err != nil {
t.Fatalf("Failed to connect to server: %v", err)
}
defer conn.Close()

buf := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}
buf = append(buf, 0x21) // PROXY
buf = append(buf, 0x12) // UDP4
buf = append(buf, 0x00, 0x0C) // 12 bytes
buf = append(buf, 192, 168, 0, 1) // saddr
buf = append(buf, 127, 0, 0, 1) // daddr
buf = append(buf, 0xDC, 0x04) // sport 56324
buf = append(buf, 0xDC, 0x04) // sport 56324
buf = append(buf, []byte("moredata")...)

conn.Write(buf)
result := <-receivedData4

if !reflect.DeepEqual(result.data, []byte("moredata")) {
t.Errorf("Unexpected data: %v", result.data)
}

if result.saddr.String() != "192.168.0.1:56324" {
t.Errorf("Unexpected source address: %v", result.saddr)
}
}
12 changes: 7 additions & 5 deletions udp/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,18 @@ func copyFromUpstream(downstream net.PacketConn, conn *connection) {
}
}

func getSocketFromMap(downstream net.PacketConn, opts *utils.Options, downstreamAddr, saddr netip.AddrPort, logger *slog.Logger,
connMap map[netip.AddrPort]*connection, socketClosures chan<- netip.AddrPort) (*connection, error) {
func getSocketFromMap(downstream net.PacketConn, opts *utils.Options, downstreamAddr, saddr, daddr netip.AddrPort,
logger *slog.Logger, connMap map[netip.AddrPort]*connection, socketClosures chan<- netip.AddrPort) (*connection, error) {
if conn := connMap[saddr]; conn != nil {
atomic.AddInt64(conn.lastActivity, 1)
return conn, nil
}

targetAddr := opts.TargetAddr6
if saddr.IsValid() {
if saddr.Addr().Is4() {
if opts.DynamicDestination && daddr.IsValid() {
targetAddr = daddr
} else if saddr.Addr().Is4() {
targetAddr = opts.TargetAddr4
}
} else {
Expand Down Expand Up @@ -162,7 +164,7 @@ func Listen(ctx context.Context, listenConfig *net.ListenConfig, opts *utils.Opt
continue
}

saddr, _, restBytes, err := proxyprotocol.ReadRemoteAddr(buffer[:n], utils.UDP)
saddr, daddr, restBytes, err := proxyprotocol.ReadRemoteAddr(buffer[:n], utils.UDP)
if err != nil {
logger.Debug("failed to parse PROXY header", "error", err, slog.String("remoteAddr", remoteAddr.String()))
continue
Expand All @@ -181,7 +183,7 @@ func Listen(ctx context.Context, listenConfig *net.ListenConfig, opts *utils.Opt
}
}

conn, err := getSocketFromMap(ln, opts, remoteAddr, saddr, logger, connectionMap, socketClosures)
conn, err := getSocketFromMap(ln, opts, remoteAddr, saddr, daddr, logger, connectionMap, socketClosures)
if err != nil {
continue
}
Expand Down
17 changes: 9 additions & 8 deletions utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@ const (
)

type Options struct {
Protocol Protocol
ListenAddr netip.AddrPort
TargetAddr4 netip.AddrPort
TargetAddr6 netip.AddrPort
Mark int
Verbose int
AllowedSubnets []netip.Prefix
UDPCloseAfter time.Duration
Protocol Protocol
ListenAddr netip.AddrPort
TargetAddr4 netip.AddrPort
TargetAddr6 netip.AddrPort
DynamicDestination bool
Mark int
Verbose int
AllowedSubnets []netip.Prefix
UDPCloseAfter time.Duration
}

func CheckOriginAllowed(remoteIP netip.Addr, allowedSubnets []netip.Prefix) bool {
Expand Down