diff --git a/README.md b/README.md index 7c21194..a8e7099 100644 --- a/README.md +++ b/README.md @@ -104,28 +104,6 @@ client ---> |mtt-client| ---> |mtt-server| ---> destination See [here](#mtt-server-multi-user-version-mtt-mu-server) -## WebSocket Secure - -mos-tls-tunnel support WebSocket Secure protocol (`wss`). WebSocket connections can be proxied by HTTP server such as Apache, as well as most of CDNs that support WebSocket. - -`wss-path` will be the path of HTTP request. - -## Multiplex (Experimental) - -mos-tls-tunnel support connection Multiplex (`mux`). It significantly reduces handshake latency, at the cost of high throughput. - -Client can set `mux-max-stream` to control the maximum number of data streams in one TCP connection. The value should be between 1 and 16. - -if `wss` is enabled, server can automatically detect whether client enable `mux` or not. But you can still use the `mux` to force the server to enable multiplex if auto-detection fails. - -## Self Signed Certificate - -On the server, if both `key` and `cert` is empty, a self signed certificate will be used. And the string from `n` will be certificate's hostname. **This self signed certificate CANNOT be verified.** - -On the client, if server's certificate can't be verified. You can enable `sv` to skip the verification. **Enable this option only if you know what you are doing. Use it with caution.** - -We recommend that you use a valid certificate all the time. A free and valid certificate can be easily obtained here. [Let's Encrypt](https://letsencrypt.org/) - ## Shadowsocks Plugin (SIP003) mos-tls-tunnel support shadowsocks [SIP003](https://shadowsocks.org/en/spec/Plugin.html). Options keys are the same as [Usage](#usage) defined. You don't have to set client and server address: `b`,`d`,`s`, shadowsocks will set those automatically. @@ -154,6 +132,28 @@ Below are example commands with [shadowsocks-libev](https://github.com/shadowsoc The Android plugin project is maintained here: [mostunnel-android](https://github.com/IrineSistiana/mostunnel-android). This is a plugin of [shadowsocks-android](https://github.com/shadowsocks/shadowsocks-android). +## WebSocket Secure + +mos-tls-tunnel support WebSocket Secure protocol (`wss`). WebSocket connections can be proxied by HTTP server such as Apache, as well as most of CDNs that support WebSocket. + +`wss-path` will be the path of HTTP request. + +## Multiplex (Experimental) + +mos-tls-tunnel support connection Multiplex (`mux`). It significantly reduces handshake latency, at the cost of high throughput. + +Client can set `mux-max-stream` to control the maximum number of data streams in one TCP connection. The value should be between 1 and 16. + +if `wss` is enabled, server can automatically detect whether client enable `mux` or not. But you can still use the `mux` to force the server to enable multiplex if auto-detection fails. + +## Self Signed Certificate + +On the server, if both `key` and `cert` is empty, a self signed certificate will be used. And the string from `n` will be certificate's hostname. **This self signed certificate CANNOT be verified.** + +On the client, if server's certificate can't be verified. You can enable `sv` to skip the verification. **Enable this option only if you know what you are doing. Use it with caution.** + +We recommend that you use a valid certificate all the time. A free and valid certificate can be easily obtained here. [Let's Encrypt](https://letsencrypt.org/) + ## mtt-server Multi-user Version (mtt-mu-server) mtt-mu-server allows multiple users to use the `wss` mode of mtt-client to transfer data on the same server port (eg: 443). Users are offloaded to the corresponding backend (`dst` destination) according to the path (`wss-path`) of their HTTP request. @@ -173,6 +173,8 @@ In general, you need the following build dependencies: You might build mos-tls-tunnel like this: +
Example
+ # get source go get -d -u github.com/IrineSistiana/mos-tls-tunnel/cmd/mtt-client go get -d -u github.com/IrineSistiana/mos-tls-tunnel/cmd/mtt-server @@ -182,6 +184,8 @@ You might build mos-tls-tunnel like this: go build -o ./ github.com/IrineSistiana/mos-tls-tunnel/cmd/mtt-client go build -o ./ github.com/IrineSistiana/mos-tls-tunnel/cmd/mtt-server go build -o ./ github.com/IrineSistiana/mos-tls-tunnel/cmd/mtt-mu-server + +
## Open Source Components / Libraries diff --git a/internal/core/client.go b/internal/core/client.go index 6baf153..7bc9d50 100644 --- a/internal/core/client.go +++ b/internal/core/client.go @@ -58,6 +58,9 @@ type Client struct { listener net.Listener log *logrus.Logger + + //test only + testDialServerRaw func() (net.Conn, error) } // NewClient inits a client instance @@ -119,7 +122,7 @@ func NewClient(c *ClientConfig) (*Client, error) { client.wssURL = "wss://" + c.ServerName + c.WSSPath internelDial := func(network, addr string) (net.Conn, error) { // overwrite url host addr - return client.netDialer.Dial(network, c.RemoteAddr) + return client.dialServerRaw() } client.wsDialer = &websocket.Dialer{ TLSClientConfig: client.tlsConf, @@ -191,6 +194,7 @@ func (client *Client) Start() error { //ForwardConn forwards this connection to server. //It will block until server-side connection is closed +//or c is closed func (client *Client) ForwardConn(c net.Conn) error { var rightConn net.Conn var err error @@ -201,7 +205,7 @@ func (client *Client) ForwardConn(c net.Conn) error { return fmt.Errorf("mux getStream: %v", err) } } else { - rightConn, err = client.newServerConn() + rightConn, err = client.dialServer() if err != nil { return fmt.Errorf("connect to remote: %v", err) } @@ -230,10 +234,11 @@ func (client *Client) dialWSS() (net.Conn, error) { } func (client *Client) dialTLS() (net.Conn, error) { - conn, err := tls.DialWithDialer(client.netDialer, "tcp", client.conf.RemoteAddr, client.tlsConf) + raw, err := client.dialServerRaw() if err != nil { return nil, err } + conn := tls.Client(raw, client.tlsConf) if err := conn.Handshake(); err != nil { conn.Close() return nil, err @@ -241,19 +246,26 @@ func (client *Client) dialTLS() (net.Conn, error) { return conn, nil } -func (client *Client) newServerConn() (net.Conn, error) { +func (client *Client) dialServer() (net.Conn, error) { if client.conf.EnableWSS { return client.dialWSS() } return client.dialTLS() } +func (client *Client) dialServerRaw() (net.Conn, error) { + if client.testDialServerRaw == nil { + return client.netDialer.Dial("tcp", client.conf.RemoteAddr) + } + return client.testDialServerRaw() +} + type smuxSessPool struct { sync.Map } func (client *Client) dialNewSmuxSess() (*smux.Session, error) { - rightConn, err := client.newServerConn() + rightConn, err := client.dialServer() if err != nil { return nil, err } diff --git a/internal/core/cmux.go b/internal/core/cmux.go index a4f3b12..b531afc 100644 --- a/internal/core/cmux.go +++ b/internal/core/cmux.go @@ -1,3 +1,22 @@ +// Copyright (c) 2019-2020 IrineSistiana +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + package core import ( diff --git a/internal/core/common.go b/internal/core/common.go index eb78d08..d7eea22 100644 --- a/internal/core/common.go +++ b/internal/core/common.go @@ -34,9 +34,13 @@ import ( "github.com/xtaci/smux" ) -var ioCopybuffPool = &sync.Pool{New: func() interface{} { - return make([]byte, defaultCopyIOBufferSize) -}} +var ( + ioCopybuffPool = &sync.Pool{New: func() interface{} { + return make([]byte, defaultCopyIOBufferSize) + }} + + longTimeAgo = time.Unix(0, 1) +) func acquireIOBuf() []byte { return ioCopybuffPool.Get().([]byte) @@ -73,40 +77,52 @@ func (fe *firstErr) getErr() error { // both of them. func openTunnel(a, b net.Conn, timeout time.Duration) error { fe := firstErr{} - - go openOneWayTunnel(a, b, timeout, &fe) - openOneWayTunnel(b, a, timeout, &fe) + muTimeout := atomic.Value{} + muTimeout.Store(timeout) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + openOneWayTunnel(a, b, &muTimeout, &fe) + wg.Done() + }() + openOneWayTunnel(b, a, &muTimeout, &fe) + wg.Wait() return fe.getErr() } // don not use this func, use openTunnel instead -func openOneWayTunnel(dst, src net.Conn, timeout time.Duration, fe *firstErr) { +func openOneWayTunnel(dst, src net.Conn, muTimeout *atomic.Value, fe *firstErr) { buf := acquireIOBuf() - _, err := copyBuffer(dst, src, buf, timeout) - // a nil err is io.EOF err, which is surpressed by copyBuffer. + _, err := copyBuffer(dst, src, buf, muTimeout) + + // a nil err might be an io.EOF err, which is surpressed by copyBuffer. // report a nil err means one conn was closed by peer. fe.report(err) //let another goroutine break from copy loop - dst.Close() + muTimeout.Store(time.Duration(0)) + src.SetDeadline(longTimeAgo) + dst.SetDeadline(longTimeAgo) src.Close() + dst.Close() releaseIOBuf(buf) } -func copyBuffer(dst net.Conn, src net.Conn, buf []byte, timeout time.Duration) (written int64, err error) { +func copyBuffer(dst net.Conn, src net.Conn, buf []byte, muTimeout *atomic.Value) (written int64, err error) { if len(buf) <= 0 { panic("buf size <= 0") } for { - src.SetReadDeadline(time.Now().Add(timeout)) + src.SetReadDeadline(time.Now().Add(muTimeout.Load().(time.Duration))) nr, er := src.Read(buf) if nr > 0 { - dst.SetWriteDeadline(time.Now().Add(timeout)) + dst.SetWriteDeadline(time.Now().Add(muTimeout.Load().(time.Duration))) nw, ew := dst.Write(buf[0:nr]) if nw > 0 { written += int64(nw) diff --git a/internal/core/core_test.go b/internal/core/core_test.go index b8c4cb0..a339065 100644 --- a/internal/core/core_test.go +++ b/internal/core/core_test.go @@ -55,11 +55,15 @@ type dstServer struct { l net.Listener } -func runDstServer(addr string, echo bool) (*dstServer, error) { - l, err := net.Listen("tcp", addr) - if err != nil { - return nil, err +func runDstServer(addr string, l net.Listener, echo bool) (*dstServer, error) { + if l == nil { + var err error + l, err = net.Listen("tcp", addr) + if err != nil { + return nil, err + } } + e := &dstServer{ l: l, } @@ -105,7 +109,10 @@ func (e *dstServer) close() error { } func test(sc *ServerConfig, cc *ClientConfig, t *testing.T) { - echo, err := runDstServer(dstAddr, true) + dummyConnL2C := newDummyDialerListener() + dummyConnS2D := newDummyDialerListener() + + echo, err := runDstServer("", dummyConnS2D, true) if err != nil { t.Fatal(err) } @@ -117,20 +124,23 @@ func test(sc *ServerConfig, cc *ClientConfig, t *testing.T) { if err != nil { t.Fatal(err) } - defer client.Close() - wg.Add(1) - go func() { - fmt.Printf("client exited [%v]", client.Start()) - wg.Done() - }() + client.testDialServerRaw = dummyConnL2C.connect + + // wg.Add(1) + // go func() { + // fmt.Printf("client exited [%v]", client.Start()) + // wg.Done() + // }() + // defer client.Close() server, err := NewServer(serverTestConfig) if err != nil { t.Fatal(err) } + server.testDialDst = dummyConnS2D.connect wg.Add(1) go func() { - fmt.Printf("server exited [%v]", server.Start()) + fmt.Printf("server exited [%v]", server.ActiveAndServe(dummyConnL2C)) wg.Done() }() defer server.Close() @@ -145,19 +155,24 @@ func test(sc *ServerConfig, cc *ClientConfig, t *testing.T) { t.Fatal(err) } - // 5 clients, 50 connections per client - for g := 0; g < 5; g++ { - wgLocalConn := sync.WaitGroup{} - for i := 0; i < 50; i++ { - wgLocalConn.Add(1) + for g := 0; g < 1; g++ { + wgClient := sync.WaitGroup{} + for i := 0; i < 1; i++ { + wgClient.Add(1) go func() { - defer wgLocalConn.Done() - localConn, err := net.Dial("tcp", clientBindAddr) - if err != nil { - t.Fatal(err) - } + defer wgClient.Done() + localConn, clientConn := net.Pipe() defer localConn.Close() - localConn.SetDeadline(time.Now().Add(time.Second * 30)) + defer clientConn.Close() + wgClient.Add(1) + go func() { + defer wgClient.Done() + err := client.ForwardConn(clientConn) + if err != nil { + t.Log(err) + } + }() + localConn.SetDeadline(time.Now().Add(time.Second * 10)) if _, err := localConn.Write(garbage); err != nil { t.Fatal(err) } @@ -172,7 +187,7 @@ func test(sc *ServerConfig, cc *ClientConfig, t *testing.T) { } }() } - wgLocalConn.Wait() + wgClient.Wait() } // force to close so wg can be released @@ -185,6 +200,14 @@ func Test_plain(t *testing.T) { test(serverTestConfig, clientTestConfig, t) } +func Test_mux(t *testing.T) { + serverTestConfig.EnableWSS = false + clientTestConfig.EnableWSS = false + serverTestConfig.EnableMux = true + clientTestConfig.EnableMux = true + test(serverTestConfig, clientTestConfig, t) +} + func Test_wss(t *testing.T) { serverTestConfig.EnableWSS = true serverTestConfig.WSSPath = "/" @@ -218,7 +241,10 @@ func Test_wss_auto_mux(t *testing.T) { func bench(sc *ServerConfig, cc *ClientConfig, b *testing.B) (conn net.Conn) { - echo, err := runDstServer(dstAddr, false) + dummyConnL2C := newDummyDialerListener() + dummyConnS2D := newDummyDialerListener() + + echo, err := runDstServer("", dummyConnS2D, false) if err != nil { b.Fatal(err) } @@ -230,31 +256,33 @@ func bench(sc *ServerConfig, cc *ClientConfig, b *testing.B) (conn net.Conn) { if err != nil { b.Fatal(err) } - defer client.Close() - wg.Add(1) - go func() { - fmt.Printf("client exited [%v]", client.Start()) - wg.Done() - }() + client.testDialServerRaw = dummyConnL2C.connect + // server server, err := NewServer(serverTestConfig) if err != nil { b.Fatal(err) } + server.testDialDst = dummyConnS2D.connect wg.Add(1) go func() { - fmt.Printf("server exited [%v]", server.Start()) + fmt.Printf("server exited [%v]", server.ActiveAndServe(dummyConnL2C)) wg.Done() }() defer server.Close() + // wait server and client time.Sleep(500 * time.Millisecond) - localConn, err := net.Dial("tcp", clientBindAddr) - if err != nil { - b.Fatal(err) - } + localConn, clientConn := net.Pipe() defer localConn.Close() + defer clientConn.Close() + go func() { + err := client.ForwardConn(clientConn) + if err != nil { + log.Printf("client forward: %v", err) + } + }() garbage := make([]byte, 64*1024) _, err = rand.Read(garbage) @@ -265,12 +293,16 @@ func bench(sc *ServerConfig, cc *ClientConfig, b *testing.B) (conn net.Conn) { b.ReportAllocs() b.ResetTimer() t := time.Now() + c := 0 for i := 0; i < b.N; i++ { - localConn.Write(garbage) + n, err := localConn.Write(garbage) + if err != nil { + b.Fatalf("write data @ %d %.2f%%, %v", c, float64(c)/float64(b.N*64*1024), err) + } + c = c + n } b.StopTimer() - - b.Logf("[%f kb/s]", float64(b.N*64)/time.Since(t).Seconds()) + b.Logf("[%dM %.2f Mb/s]", c/1024/1024, float64(c/1024/1024)/time.Since(t).Seconds()) // force to close so wg can be released client.Close() diff --git a/internal/core/dummy_conn.go b/internal/core/dummy_conn.go new file mode 100644 index 0000000..a217f85 --- /dev/null +++ b/internal/core/dummy_conn.go @@ -0,0 +1,77 @@ +// Copyright (c) 2019-2020 IrineSistiana +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package core + +import ( + "io" + "net" + "sync" +) + +type dummyDialerListener struct { + notifyListener chan net.Conn + closeOnce sync.Once + isClosed chan struct{} +} + +func newDummyDialerListener() *dummyDialerListener { + return &dummyDialerListener{ + notifyListener: make(chan net.Conn, 0), + isClosed: make(chan struct{}, 0), + } +} + +func (d *dummyDialerListener) Dial(network string, address string) (net.Conn, error) { + return d.connect() +} + +func (d *dummyDialerListener) connect() (net.Conn, error) { + c1, c2 := net.Pipe() + select { + case d.notifyListener <- c2: + return c1, nil + case <-d.isClosed: + return nil, io.ErrClosedPipe + } +} + +func (d *dummyDialerListener) Accept() (net.Conn, error) { + select { + case c := <-d.notifyListener: + return c, nil + case <-d.isClosed: + return nil, io.ErrClosedPipe + } +} + +func (d *dummyDialerListener) Addr() net.Addr { + return pipeAddr{} +} + +func (d *dummyDialerListener) Close() error { + d.closeOnce.Do(func() { close(d.isClosed) }) + return nil +} + +// steal from golang net pipeAddr +type pipeAddr struct{} + +func (pipeAddr) Network() string { return "pipe" } +func (pipeAddr) String() string { return "pipe" } diff --git a/internal/core/multi_user_server_test.go b/internal/core/multi_user_server_test.go index 2e1b868..65f48c7 100644 --- a/internal/core/multi_user_server_test.go +++ b/internal/core/multi_user_server_test.go @@ -25,6 +25,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "io/ioutil" "net" "net/http" @@ -41,7 +42,7 @@ var ( ) func Test_MU(t *testing.T) { - echo, err := runDstServer(muDstAddr, true) + echo, err := runDstServer(muDstAddr, nil, true) if err != nil { t.Fatal(err) } @@ -132,18 +133,18 @@ func Test_MU(t *testing.T) { t.Fatal(err) } - localConn.SetWriteDeadline(time.Now().Add(time.Second)) + localConn.SetDeadline(time.Now().Add(time.Second)) + //test write if _, err := localConn.Write(garbage); err != nil { return fmt.Errorf("write to client: %v", err) } - //test read - buf := make([]byte, garbageSize) - _, err = localConn.Read(buf) + b, err := ioutil.ReadAll(io.LimitReader(localConn, int64(garbageSize))) if err != nil { return fmt.Errorf("read from client: %v", err) } - if !bytes.Equal(buf, garbage) { + + if !bytes.Equal(b, garbage) { t.Fatal("data err") } diff --git a/internal/core/server.go b/internal/core/server.go index 5d0411e..b8686eb 100644 --- a/internal/core/server.go +++ b/internal/core/server.go @@ -60,6 +60,9 @@ type Server struct { smuxConfig *smux.Config log *logrus.Logger + + //test only + testDialDst func() (net.Conn, error) } func NewServer(c *ServerConfig) (*Server, error) { @@ -138,11 +141,15 @@ func (server *Server) Start() error { if err != nil { return fmt.Errorf("listener.Listen: %v", err) } + defer l.Close() + + return server.ActiveAndServe(l) +} + +func (server *Server) ActiveAndServe(l net.Listener) error { if !server.conf.DisableTLS { l = tls.NewListener(l, server.tlsConf) } - defer l.Close() - server.listenerLocker.Lock() server.listener = l server.listenerLocker.Unlock() @@ -151,7 +158,7 @@ func (server *Server) Start() error { if server.conf.EnableWSS { httpMux := http.NewServeMux() httpMux.Handle(server.conf.WSSPath, server) - err = http.Serve(server.listener, httpMux) + err := http.Serve(server.listener, httpMux) if err != nil { return fmt.Errorf("http.Serve: %v", err) } @@ -206,7 +213,10 @@ func (server *Server) handleClientConn(leftConn net.Conn, requestEntry *logrus.E } defer rightConn.Close() - openTunnel(rightConn, leftConn, server.conf.Timeout) + err = openTunnel(rightConn, leftConn, server.conf.Timeout) + if err != nil { + requestEntry.Errorf("openTunnel, %v", err) + } } func (server *Server) handleClientMuxConn(leftConn net.Conn, requestEntry *logrus.Entry) { @@ -240,6 +250,9 @@ func (server *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (server *Server) dialDst() (net.Conn, error) { + if server.testDialDst != nil { + return server.testDialDst() + } return server.netDialer.Dial("tcp", server.conf.DstAddr) } diff --git a/internal/core/websocket.go b/internal/core/websocket.go index a3e68f8..216ac98 100644 --- a/internal/core/websocket.go +++ b/internal/core/websocket.go @@ -22,6 +22,7 @@ package core import ( "io" "net" + "sync" "time" "github.com/gorilla/websocket" @@ -37,8 +38,9 @@ var ( // webSocketConnWrapper is a wrapper for net.Conn over WebSocket connection. type webSocketConnWrapper struct { - ws *websocket.Conn - reader io.Reader + ws *websocket.Conn + reader io.Reader + closeOnce sync.Once } func wrapWebSocketConn(c *websocket.Conn) net.Conn { @@ -81,7 +83,15 @@ func (c *webSocketConnWrapper) Write(b []byte) (int, error) { } func (c *webSocketConnWrapper) Close() error { - c.ws.WriteMessage(websocket.CloseMessage, websocketFormatCloseMessage) + return c.CloseWithDeadLine(time.Millisecond * 100) +} + +func (c *webSocketConnWrapper) CloseWithDeadLine(t time.Duration) error { + c.closeOnce.Do(func() { + // set WriteDeadline to avoid sub conn blocking here forever!! + c.ws.SetWriteDeadline(time.Now().Add(t)) + c.ws.WriteMessage(websocket.CloseMessage, websocketFormatCloseMessage) + }) return c.ws.Close() }