From 394127395b4f081d1fa792602a65f70af5eeb49c Mon Sep 17 00:00:00 2001 From: IrineSistiana <49315432+IrineSistiana@users.noreply.github.com> Date: Sat, 17 Jul 2021 17:07:45 +0800 Subject: [PATCH] update --- core/client.go | 39 +++++++++------- core/core_test.go | 14 +++--- core/server.go | 56 +++++++++++++---------- core/smux.go | 1 + core/tcp.go | 14 ++++++ core/utils.go | 6 +-- main.go | 110 +++++++++++++++++++++++++--------------------- 7 files changed, 140 insertions(+), 100 deletions(-) diff --git a/core/client.go b/core/client.go index f20898a..149f7ca 100644 --- a/core/client.go +++ b/core/client.go @@ -31,6 +31,7 @@ import ( type Client struct { Listener net.Listener ServerAddr string + NoTLS bool Auth string ServerName string CertPool *x509.CertPool @@ -52,11 +53,13 @@ func (c *Client) ActiveAndServe() error { Control: GetControlFunc(&TcpConfig{AndroidVPN: c.AndroidVPNMode, EnableTFO: c.TFO}), } - c.tlsConfig = new(tls.Config) - c.tlsConfig.NextProtos = []string{"http/1.1", "h2"} - c.tlsConfig.ServerName = c.ServerName - c.tlsConfig.RootCAs = c.CertPool - c.tlsConfig.InsecureSkipVerify = c.InsecureSkipVerify + if !c.NoTLS { + c.tlsConfig = new(tls.Config) + c.tlsConfig.NextProtos = []string{"http/1.1", "h2"} + c.tlsConfig.ServerName = c.ServerName + c.tlsConfig.RootCAs = c.CertPool + c.tlsConfig.InsecureSkipVerify = c.InsecureSkipVerify + } if len(c.Auth) > 0 { c.auth = md5.Sum([]byte(c.Auth)) @@ -71,6 +74,7 @@ func (c *Client) ActiveAndServe() error { if err != nil { return fmt.Errorf("l.Accept(): %w", err) } + reduceLoopbackSocketBuf(localConn) go func() { defer localConn.Close() @@ -100,22 +104,25 @@ func (c *Client) ActiveAndServe() error { } } -func (c *Client) dialServerConn() (serverConn net.Conn, err error) { - serverRawConn, err := c.dialer.Dial("tcp", c.ServerAddr) +func (c *Client) dialServerConn() (net.Conn, error) { + serverConn, err := c.dialer.Dial("tcp", c.ServerAddr) if err != nil { return nil, err } - serverTLSConn := tls.Client(serverRawConn, c.tlsConfig) - if err := tls13HandshakeWithTimeout(serverTLSConn, time.Second*5); err != nil { - serverRawConn.Close() - return nil, err + if !c.NoTLS { + serverTLSConn := tls.Client(serverConn, c.tlsConfig) + if err := tls13HandshakeWithTimeout(serverTLSConn, time.Second*5); err != nil { + serverTLSConn.Close() + return nil, err + } + serverConn = serverTLSConn } // write auth if len(c.Auth) > 0 { - if _, err := serverTLSConn.Write(c.auth[:]); err != nil { - serverRawConn.Close() + if _, err := serverConn.Write(c.auth[:]); err != nil { + serverConn.Close() return nil, fmt.Errorf("failed to write auth: %w", err) } } @@ -125,10 +132,10 @@ func (c *Client) dialServerConn() (serverConn net.Conn, err error) { if c.Mux > 0 { mode = modeMux } - if _, err := serverTLSConn.Write([]byte{mode}); err != nil { - serverRawConn.Close() + if _, err := serverConn.Write([]byte{mode}); err != nil { + serverConn.Close() return nil, fmt.Errorf("failed to write mode: %w", err) } - return serverTLSConn, nil + return serverConn, nil } diff --git a/core/core_test.go b/core/core_test.go index 8a9a362..c654067 100644 --- a/core/core_test.go +++ b/core/core_test.go @@ -68,7 +68,7 @@ func Test_main(t *testing.T) { }() // test1 - test := func(t *testing.T, mux int) { + test := func(t *testing.T, mux int, noTLS bool) { // start server _, keyPEM, certPEM, err := GenerateCertificate("example.com") cert, err := tls.X509KeyPair(certPEM, keyPEM) @@ -141,16 +141,18 @@ func Test_main(t *testing.T) { } tests := []struct { - name string - mux int + name string + mux int + noTLS bool }{ - {"plain", 0}, - {"mux", 5}, + {"plain", 0, false}, + {"mux", 5, false}, + {"no tls", 5, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - test(t, tt.mux) + test(t, tt.mux, tt.noTLS) }) } } diff --git a/core/server.go b/core/server.go index 9d0ce3a..b0ccb6d 100644 --- a/core/server.go +++ b/core/server.go @@ -33,83 +33,90 @@ import ( type Server struct { Listener net.Listener Dst string + NoTLS bool Auth string Certificates []tls.Certificate Timeout time.Duration - auth [16]byte + auth [16]byte + tlsConfig *tls.Config } func (s *Server) ActiveAndServe() error { - tlsConfig := new(tls.Config) - tlsConfig.NextProtos = []string{"h2"} - tlsConfig.Certificates = s.Certificates + if !s.NoTLS { + s.tlsConfig = new(tls.Config) + s.tlsConfig.NextProtos = []string{"h2"} + s.tlsConfig.Certificates = s.Certificates + } if len(s.Auth) > 0 { s.auth = md5.Sum([]byte(s.Auth)) } for { - clientRawConn, err := s.Listener.Accept() + clientConn, err := s.Listener.Accept() if err != nil { return fmt.Errorf("l.Accept(): %w", err) } go func() { - clientTLSConn := tls.Server(clientRawConn, tlsConfig) - defer clientTLSConn.Close() + defer clientConn.Close() - // handshake - if err := tls13HandshakeWithTimeout(clientTLSConn, time.Second*5); err != nil { - log.Printf("ERROR: %s, tls13HandshakeWithTimeout: %v", clientRawConn.RemoteAddr(), err) - return + if !s.NoTLS { + clientTLSConn := tls.Server(clientConn, s.tlsConfig) + // handshake + if err := tls13HandshakeWithTimeout(clientTLSConn, time.Second*5); err != nil { + log.Printf("ERROR: %s, tls13HandshakeWithTimeout: %v", clientConn.RemoteAddr(), err) + return + } + clientConn = clientTLSConn } // check auth if len(s.Auth) > 0 { auth := make([]byte, 16) - if _, err := io.ReadFull(clientTLSConn, auth); err != nil { - log.Printf("ERROR: %s, read client auth header: %v", clientRawConn.RemoteAddr(), err) + if _, err := io.ReadFull(clientConn, auth); err != nil { + log.Printf("ERROR: %s, read client auth header: %v", clientConn.RemoteAddr(), err) return } if !bytes.Equal(s.auth[:], auth) { - log.Printf("ERROR: %s, auth failed", clientRawConn.RemoteAddr()) - discard(clientTLSConn) + log.Printf("ERROR: %s, auth failed", clientConn.RemoteAddr()) + discardRead(clientConn, time.Second*15) return } } // mode header := make([]byte, 1) - if _, err := io.ReadFull(clientTLSConn, header); err != nil { - log.Printf("ERROR: %s, read client mode header: %v", clientRawConn.RemoteAddr(), err) + if _, err := io.ReadFull(clientConn, header); err != nil { + log.Printf("ERROR: %s, read client mode header: %v", clientConn.RemoteAddr(), err) return } switch header[0] { case modePlain: - if err := s.handleClientConn(clientTLSConn); err != nil { - log.Printf("ERROR: %s, handleClientConn: %v", clientRawConn.RemoteAddr(), err) + if err := s.handleClientConn(clientConn); err != nil { + log.Printf("ERROR: %s, handleClientConn: %v", clientConn.RemoteAddr(), err) return } case modeMux: - err := s.handleClientMux(clientTLSConn) + err := s.handleClientMux(clientConn) if err != nil { - log.Printf("ERROR: %s, handleClientMux: %v", clientRawConn.RemoteAddr(), err) + log.Printf("ERROR: %s, handleClientMux: %v", clientConn.RemoteAddr(), err) return } default: - log.Printf("ERROR: %s, invalid header %d", clientRawConn.RemoteAddr(), header[0]) + log.Printf("ERROR: %s, invalid header %d", clientConn.RemoteAddr(), header[0]) return } }() } } -func discard(c net.Conn) { - c.SetDeadline(time.Now().Add(time.Second * 15)) +func discardRead(c net.Conn, t time.Duration) { + c.SetDeadline(time.Now().Add(t)) buf := make([]byte, 512) for { _, err := c.Read(buf) @@ -124,6 +131,7 @@ func (s *Server) handleClientConn(cc net.Conn) (err error) { if err != nil { return fmt.Errorf("net.Dial: %v", err) } + reduceLoopbackSocketBuf(dstConn) defer dstConn.Close() if err := ctunnel.OpenTunnel(dstConn, cc, s.Timeout); err != nil { diff --git a/core/smux.go b/core/smux.go index 7c5444c..436a5df 100644 --- a/core/smux.go +++ b/core/smux.go @@ -121,6 +121,7 @@ func (m *muxPool) dialSessLocked() (call *dialCall) { if err != nil { call.err = err close(call.done) + return } sess, err := smux.Client(c, muxConfig) diff --git a/core/tcp.go b/core/tcp.go index 7b63d0d..c9dbda5 100644 --- a/core/tcp.go +++ b/core/tcp.go @@ -17,7 +17,21 @@ package core +import "net" + type TcpConfig struct { AndroidVPN bool EnableTFO bool } + +func reduceLoopbackSocketBuf(c net.Conn) { + tcpConn, ok := c.(*net.TCPConn) + if ok && isLoopbackConn(tcpConn) { + tcpConn.SetReadBuffer(32 * 1024) + tcpConn.SetWriteBuffer(32 * 1024) + } +} + +func isLoopbackConn(c *net.TCPConn) bool { + return c.LocalAddr().(*net.TCPAddr).IP.IsLoopback() || c.RemoteAddr().(*net.TCPAddr).IP.IsLoopback() +} diff --git a/core/utils.go b/core/utils.go index 07b76a8..cd47183 100644 --- a/core/utils.go +++ b/core/utils.go @@ -69,9 +69,8 @@ func GenerateCertificate(serverName string) (dnsName string, keyPEM, certPEM []b SerialNumber: serialNumber, Subject: pkix.Name{CommonName: dnsName}, DNSNames: []string{dnsName}, - - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, @@ -82,6 +81,7 @@ func GenerateCertificate(serverName string) (dnsName string, keyPEM, certPEM []b if err != nil { return } + b, err := x509.MarshalPKCS8PrivateKey(key) if err != nil { return diff --git a/main.go b/main.go index f06717d..ec82798 100644 --- a/main.go +++ b/main.go @@ -51,7 +51,7 @@ func main() { }() var bindAddr, dstAddr, auth, serverName, cca, ca, cert, key string - var insecureSkipVerify, isServer, tfo, vpn, genCert, showVersion bool + var noTLS, insecureSkipVerify, isServer, tfo, vpn, genCert, showVersion bool var cpu, mux int var timeout time.Duration var timeoutFlag int @@ -61,6 +61,7 @@ func main() { commandLine.StringVar(&bindAddr, "b", "", "[Host:Port] bind address") commandLine.StringVar(&dstAddr, "d", "", "[Host:Port] destination address") commandLine.StringVar(&auth, "auth", "", "server password") + commandLine.BoolVar(&noTLS, "no-tls", false, "disable TLS (debug only)") // client only commandLine.IntVar(&mux, "mux", 0, "enable mux") @@ -76,7 +77,7 @@ func main() { commandLine.StringVar(&key, "key", "", "[Path] PEM key file") // etc - commandLine.IntVar(&timeoutFlag, "t", 300, "timeout after sec") + commandLine.IntVar(&timeoutFlag, "t", 300, "timeout in sec") commandLine.BoolVar(&tfo, "fast-open", false, "enable tfo, only available on linux 4.11+") commandLine.IntVar(&cpu, "cpu", runtime.NumCPU(), "the maximum number of CPUs that can be executing simultaneously") @@ -167,14 +168,16 @@ func main() { setStrIfNotEmpty(&dstAddr, s) s, _ = sip003Args.SS_PLUGIN_OPTIONS["auth"] setStrIfNotEmpty(&auth, s) - s, _ = sip003Args.SS_PLUGIN_OPTIONS["mux"] - if err := setIntIfNotZero(&mux, s); err != nil { - log.Fatalf("main: invalid mux value, %v", err) - } + _, ok = sip003Args.SS_PLUGIN_OPTIONS["no-tls"] + noTLS = noTLS || ok // client s, _ = sip003Args.SS_PLUGIN_OPTIONS["n"] setStrIfNotEmpty(&serverName, s) + s, _ = sip003Args.SS_PLUGIN_OPTIONS["mux"] + if err := setIntIfNotZero(&mux, s); err != nil { + log.Fatalf("main: invalid mux value, %v", err) + } s, _ = sip003Args.SS_PLUGIN_OPTIONS["ca"] setStrIfNotEmpty(&ca, s) s, _ = sip003Args.SS_PLUGIN_OPTIONS["cca"] @@ -225,29 +228,30 @@ func main() { if isServer { var certificates []tls.Certificate - - switch { - case len(cert) == 0 && len(key) == 0: // no cert and key - log.Printf("main: warnning: neither -key nor -cert is specified") - - dnsName, keyPEM, certPEM, err := core.GenerateCertificate(serverName) - if err != nil { - log.Fatalf("main: generateCertificate: %v", err) - } - log.Printf("main: warnning: using tmp certificate %s", dnsName) - cer, err := tls.X509KeyPair(certPEM, keyPEM) - if err != nil { - log.Fatalf("main: X509KeyPair: %v", err) - } - certificates = []tls.Certificate{cer} - case len(cert) != 0 && len(key) != 0: // has cert and key - cer, err := tls.LoadX509KeyPair(cert, key) //load cert - if err != nil { - log.Fatalf("main: LoadX509KeyPair: %v", err) + if !noTLS { + switch { + case len(cert) == 0 && len(key) == 0: // no cert and key + log.Printf("main: warnning: neither -key nor -cert is specified") + + dnsName, keyPEM, certPEM, err := core.GenerateCertificate(serverName) + if err != nil { + log.Fatalf("main: generateCertificate: %v", err) + } + log.Printf("main: warnning: using tmp certificate %s", dnsName) + cer, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + log.Fatalf("main: X509KeyPair: %v", err) + } + certificates = []tls.Certificate{cer} + case len(cert) != 0 && len(key) != 0: // has cert and key + cer, err := tls.LoadX509KeyPair(cert, key) //load cert + if err != nil { + log.Fatalf("main: LoadX509KeyPair: %v", err) + } + certificates = []tls.Certificate{cer} + default: + log.Fatal("main: server must have a X509 key pair, aka. -cert and -key") } - certificates = []tls.Certificate{cer} - default: - log.Fatal("main: server must have a X509 key pair, aka. -cert and -key") } lc := net.ListenConfig{Control: core.GetControlFunc(&core.TcpConfig{EnableTFO: tfo})} @@ -258,8 +262,9 @@ func main() { server := core.Server{ Listener: l, - Auth: auth, Dst: dstAddr, + NoTLS: noTLS, + Auth: auth, Certificates: certificates, Timeout: timeout, } @@ -270,31 +275,33 @@ func main() { } } else { // do client - if len(serverName) == 0 { - serverName = strings.SplitN(dstAddr, ":", 2)[0] - } var rootCAs *x509.CertPool - - switch { - case len(cca) != 0: - cca = strings.TrimRight(cca, "=") - pem, err := base64.RawStdEncoding.DecodeString(cca) - if err != nil { - log.Fatalf("main: base64.RawStdEncoding.DecodeString: %v", err) + if !noTLS { + if len(serverName) == 0 { + serverName = strings.SplitN(dstAddr, ":", 2)[0] } - rootCAs = x509.NewCertPool() - if ok := rootCAs.AppendCertsFromPEM(pem); !ok { - log.Fatal("main: AppendCertsFromPEM failed, cca is invalid") - } - case len(ca) != 0: - rootCAs = x509.NewCertPool() - certPEMBlock, err := ioutil.ReadFile(ca) - if err != nil { - log.Fatalf("main: ReadFile ca [%s], %v", ca, err) - } - if ok := rootCAs.AppendCertsFromPEM(certPEMBlock); !ok { - log.Fatal("main: AppendCertsFromPEM failed, ca is invalid") + switch { + case len(cca) != 0: + cca = strings.TrimRight(cca, "=") + pem, err := base64.RawStdEncoding.DecodeString(cca) + if err != nil { + log.Fatalf("main: base64.RawStdEncoding.DecodeString: %v", err) + } + + rootCAs = x509.NewCertPool() + if ok := rootCAs.AppendCertsFromPEM(pem); !ok { + log.Fatal("main: AppendCertsFromPEM failed, cca is invalid") + } + case len(ca) != 0: + rootCAs = x509.NewCertPool() + certPEMBlock, err := ioutil.ReadFile(ca) + if err != nil { + log.Fatalf("main: ReadFile ca [%s], %v", ca, err) + } + if ok := rootCAs.AppendCertsFromPEM(certPEMBlock); !ok { + log.Fatal("main: AppendCertsFromPEM failed, ca is invalid") + } } } @@ -307,6 +314,7 @@ func main() { client := core.Client{ Listener: l, ServerAddr: dstAddr, + NoTLS: noTLS, Auth: auth, ServerName: serverName, CertPool: rootCAs,