diff --git a/cmd/warp/main.go b/cmd/warp/main.go index 33747ef..8b3b44f 100644 --- a/cmd/warp/main.go +++ b/cmd/warp/main.go @@ -18,6 +18,7 @@ var ( builtBy = "" ip = flag.String("ip", "127.0.0.1", "listen ip") port = flag.Int("port", 0, "listen port") + oip = flag.String("oip", "", "outbound ip") opr = flag.String("opr", "", "outbound port range: 12000-12500") verFlag = flag.Bool("version", false, "show build version") oprRe = regexp.MustCompile(`^([1-9][0-9]{0,5})-([1-9][0-9]{0,5})$`) @@ -33,7 +34,11 @@ func main() { return } - w := &warp.Server{Addr: *ip, Port: *port} + w := &warp.Server{ + Addr: *ip, + Port: *port, + OutboundAddr: *oip, + } trimedOpr := strings.TrimSpace(*opr) if trimedOpr != "" { diff --git a/server.go b/server.go index 90a018c..2a261e6 100644 --- a/server.go +++ b/server.go @@ -15,6 +15,7 @@ type Server struct { Addr string Port int Hooks []Hook + OutboundAddr string OutboundPorts *PortRange log *log.Logger } @@ -63,13 +64,27 @@ func (s *Server) Start() error { } } +func (s *Server) getOutboundAddrAndPort() (string, int) { + port := 0 + addr := s.Addr + + if s.OutboundPorts != nil { + port, _ = s.OutboundPorts.TakeOut() + } + if s.OutboundAddr != "" { + addr = s.OutboundAddr + } + + return addr, port +} + func (s *Server) HandleConnection(conn net.Conn) { uuid := GenID().String() s.log.Printf("%s %s connected from %s", uuid, onPxy, conn.RemoteAddr()) raddr, err := s.OriginalAddrDst(conn) if err != nil { - s.log.Printf("%s %s original addr error: %#v", uuid, onPxy, err) + s.log.Printf("%s %s original addr error: %s(%#v)", uuid, onPxy, err.Error(), err) return } @@ -87,16 +102,8 @@ func (s *Server) HandleConnection(conn net.Conn) { } }() - outboundPort := 0 - if s.OutboundPorts != nil { - outboundPort, err = s.OutboundPorts.TakeOut() - if err != nil { - s.log.Printf("%s %s outbound ports take out error: %s(%#v)", uuid, onPxy, err.Error(), err) - return - } - } - - laddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", s.Addr, outboundPort)) + oAddr, oPort := s.getOutboundAddrAndPort() + laddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", oAddr, oPort)) if err != nil { s.log.Printf("%s %s resolve tcp addr error: %s(%#v)", uuid, onPxy, err.Error(), err) return