Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
IrineSistiana committed Jul 17, 2021
1 parent 944e452 commit 3941273
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 100 deletions.
39 changes: 23 additions & 16 deletions core/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
type Client struct {
Listener net.Listener
ServerAddr string
NoTLS bool
Auth string
ServerName string
CertPool *x509.CertPool
Expand All @@ -52,11 +53,13 @@ func (c *Client) ActiveAndServe() error {
Control: GetControlFunc(&TcpConfig{AndroidVPN: c.AndroidVPNMode, EnableTFO: c.TFO}),
}

c.tlsConfig = new(tls.Config)
c.tlsConfig.NextProtos = []string{"http/1.1", "h2"}
c.tlsConfig.ServerName = c.ServerName
c.tlsConfig.RootCAs = c.CertPool
c.tlsConfig.InsecureSkipVerify = c.InsecureSkipVerify
if !c.NoTLS {
c.tlsConfig = new(tls.Config)
c.tlsConfig.NextProtos = []string{"http/1.1", "h2"}
c.tlsConfig.ServerName = c.ServerName
c.tlsConfig.RootCAs = c.CertPool
c.tlsConfig.InsecureSkipVerify = c.InsecureSkipVerify
}

if len(c.Auth) > 0 {
c.auth = md5.Sum([]byte(c.Auth))
Expand All @@ -71,6 +74,7 @@ func (c *Client) ActiveAndServe() error {
if err != nil {
return fmt.Errorf("l.Accept(): %w", err)
}
reduceLoopbackSocketBuf(localConn)

go func() {
defer localConn.Close()
Expand Down Expand Up @@ -100,22 +104,25 @@ func (c *Client) ActiveAndServe() error {
}
}

func (c *Client) dialServerConn() (serverConn net.Conn, err error) {
serverRawConn, err := c.dialer.Dial("tcp", c.ServerAddr)
func (c *Client) dialServerConn() (net.Conn, error) {
serverConn, err := c.dialer.Dial("tcp", c.ServerAddr)
if err != nil {
return nil, err
}

serverTLSConn := tls.Client(serverRawConn, c.tlsConfig)
if err := tls13HandshakeWithTimeout(serverTLSConn, time.Second*5); err != nil {
serverRawConn.Close()
return nil, err
if !c.NoTLS {
serverTLSConn := tls.Client(serverConn, c.tlsConfig)
if err := tls13HandshakeWithTimeout(serverTLSConn, time.Second*5); err != nil {
serverTLSConn.Close()
return nil, err
}
serverConn = serverTLSConn
}

// write auth
if len(c.Auth) > 0 {
if _, err := serverTLSConn.Write(c.auth[:]); err != nil {
serverRawConn.Close()
if _, err := serverConn.Write(c.auth[:]); err != nil {
serverConn.Close()
return nil, fmt.Errorf("failed to write auth: %w", err)
}
}
Expand All @@ -125,10 +132,10 @@ func (c *Client) dialServerConn() (serverConn net.Conn, err error) {
if c.Mux > 0 {
mode = modeMux
}
if _, err := serverTLSConn.Write([]byte{mode}); err != nil {
serverRawConn.Close()
if _, err := serverConn.Write([]byte{mode}); err != nil {
serverConn.Close()
return nil, fmt.Errorf("failed to write mode: %w", err)
}

return serverTLSConn, nil
return serverConn, nil
}
14 changes: 8 additions & 6 deletions core/core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func Test_main(t *testing.T) {
}()

// test1
test := func(t *testing.T, mux int) {
test := func(t *testing.T, mux int, noTLS bool) {
// start server
_, keyPEM, certPEM, err := GenerateCertificate("example.com")
cert, err := tls.X509KeyPair(certPEM, keyPEM)
Expand Down Expand Up @@ -141,16 +141,18 @@ func Test_main(t *testing.T) {
}

tests := []struct {
name string
mux int
name string
mux int
noTLS bool
}{
{"plain", 0},
{"mux", 5},
{"plain", 0, false},
{"mux", 5, false},
{"no tls", 5, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
test(t, tt.mux)
test(t, tt.mux, tt.noTLS)
})
}
}
56 changes: 32 additions & 24 deletions core/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,83 +33,90 @@ import (
type Server struct {
Listener net.Listener
Dst string
NoTLS bool
Auth string

Certificates []tls.Certificate
Timeout time.Duration

auth [16]byte
auth [16]byte
tlsConfig *tls.Config
}

func (s *Server) ActiveAndServe() error {
tlsConfig := new(tls.Config)
tlsConfig.NextProtos = []string{"h2"}
tlsConfig.Certificates = s.Certificates
if !s.NoTLS {
s.tlsConfig = new(tls.Config)
s.tlsConfig.NextProtos = []string{"h2"}
s.tlsConfig.Certificates = s.Certificates
}

if len(s.Auth) > 0 {
s.auth = md5.Sum([]byte(s.Auth))
}

for {
clientRawConn, err := s.Listener.Accept()
clientConn, err := s.Listener.Accept()
if err != nil {
return fmt.Errorf("l.Accept(): %w", err)
}

go func() {
clientTLSConn := tls.Server(clientRawConn, tlsConfig)
defer clientTLSConn.Close()
defer clientConn.Close()

// handshake
if err := tls13HandshakeWithTimeout(clientTLSConn, time.Second*5); err != nil {
log.Printf("ERROR: %s, tls13HandshakeWithTimeout: %v", clientRawConn.RemoteAddr(), err)
return
if !s.NoTLS {
clientTLSConn := tls.Server(clientConn, s.tlsConfig)
// handshake
if err := tls13HandshakeWithTimeout(clientTLSConn, time.Second*5); err != nil {
log.Printf("ERROR: %s, tls13HandshakeWithTimeout: %v", clientConn.RemoteAddr(), err)
return
}
clientConn = clientTLSConn
}

// check auth
if len(s.Auth) > 0 {
auth := make([]byte, 16)
if _, err := io.ReadFull(clientTLSConn, auth); err != nil {
log.Printf("ERROR: %s, read client auth header: %v", clientRawConn.RemoteAddr(), err)
if _, err := io.ReadFull(clientConn, auth); err != nil {
log.Printf("ERROR: %s, read client auth header: %v", clientConn.RemoteAddr(), err)
return
}

if !bytes.Equal(s.auth[:], auth) {
log.Printf("ERROR: %s, auth failed", clientRawConn.RemoteAddr())
discard(clientTLSConn)
log.Printf("ERROR: %s, auth failed", clientConn.RemoteAddr())
discardRead(clientConn, time.Second*15)
return
}
}

// mode
header := make([]byte, 1)
if _, err := io.ReadFull(clientTLSConn, header); err != nil {
log.Printf("ERROR: %s, read client mode header: %v", clientRawConn.RemoteAddr(), err)
if _, err := io.ReadFull(clientConn, header); err != nil {
log.Printf("ERROR: %s, read client mode header: %v", clientConn.RemoteAddr(), err)
return
}

switch header[0] {
case modePlain:
if err := s.handleClientConn(clientTLSConn); err != nil {
log.Printf("ERROR: %s, handleClientConn: %v", clientRawConn.RemoteAddr(), err)
if err := s.handleClientConn(clientConn); err != nil {
log.Printf("ERROR: %s, handleClientConn: %v", clientConn.RemoteAddr(), err)
return
}
case modeMux:
err := s.handleClientMux(clientTLSConn)
err := s.handleClientMux(clientConn)
if err != nil {
log.Printf("ERROR: %s, handleClientMux: %v", clientRawConn.RemoteAddr(), err)
log.Printf("ERROR: %s, handleClientMux: %v", clientConn.RemoteAddr(), err)
return
}
default:
log.Printf("ERROR: %s, invalid header %d", clientRawConn.RemoteAddr(), header[0])
log.Printf("ERROR: %s, invalid header %d", clientConn.RemoteAddr(), header[0])
return
}
}()
}
}

func discard(c net.Conn) {
c.SetDeadline(time.Now().Add(time.Second * 15))
func discardRead(c net.Conn, t time.Duration) {
c.SetDeadline(time.Now().Add(t))
buf := make([]byte, 512)
for {
_, err := c.Read(buf)
Expand All @@ -124,6 +131,7 @@ func (s *Server) handleClientConn(cc net.Conn) (err error) {
if err != nil {
return fmt.Errorf("net.Dial: %v", err)
}
reduceLoopbackSocketBuf(dstConn)
defer dstConn.Close()

if err := ctunnel.OpenTunnel(dstConn, cc, s.Timeout); err != nil {
Expand Down
1 change: 1 addition & 0 deletions core/smux.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ func (m *muxPool) dialSessLocked() (call *dialCall) {
if err != nil {
call.err = err
close(call.done)
return
}

sess, err := smux.Client(c, muxConfig)
Expand Down
14 changes: 14 additions & 0 deletions core/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,21 @@

package core

import "net"

type TcpConfig struct {
AndroidVPN bool
EnableTFO bool
}

func reduceLoopbackSocketBuf(c net.Conn) {
tcpConn, ok := c.(*net.TCPConn)
if ok && isLoopbackConn(tcpConn) {
tcpConn.SetReadBuffer(32 * 1024)
tcpConn.SetWriteBuffer(32 * 1024)
}
}

func isLoopbackConn(c *net.TCPConn) bool {
return c.LocalAddr().(*net.TCPAddr).IP.IsLoopback() || c.RemoteAddr().(*net.TCPAddr).IP.IsLoopback()
}
6 changes: 3 additions & 3 deletions core/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,8 @@ func GenerateCertificate(serverName string) (dnsName string, keyPEM, certPEM []b
SerialNumber: serialNumber,
Subject: pkix.Name{CommonName: dnsName},
DNSNames: []string{dnsName},

NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),

KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
Expand All @@ -82,6 +81,7 @@ func GenerateCertificate(serverName string) (dnsName string, keyPEM, certPEM []b
if err != nil {
return
}

b, err := x509.MarshalPKCS8PrivateKey(key)
if err != nil {
return
Expand Down
Loading

0 comments on commit 3941273

Please sign in to comment.