Skip to content

Commit

Permalink
Add reconnection mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
caotingv committed Oct 20, 2022
1 parent 9e2023c commit a89f6b5
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 2 deletions.
49 changes: 48 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net"
"strings"
"time"

"github.com/kungze/quic-tun/pkg/constants"
"github.com/kungze/quic-tun/pkg/log"
Expand All @@ -22,12 +23,18 @@ type ClientEndpoint struct {
TlsConfig *tls.Config
}

var (
session quic.Session
)

func (c *ClientEndpoint) Start() {
// Dial server endpoint
session, err := quic.DialAddr(c.ServerEndpointSocket, c.TlsConfig, &quic.Config{KeepAlive: true})
var err error
session, err = quic.DialAddr(c.ServerEndpointSocket, c.TlsConfig, &quic.Config{KeepAlive: true})
if err != nil {
panic(err)
}
go c.keepClientWorking()
parent_ctx := context.WithValue(context.TODO(), constants.CtxRemoteEndpointAddr, session.RemoteAddr().String())
// Listen on a TCP or UNIX socket, wait client application's connection request.
localSocket := strings.Split(c.LocalSocket, ":")
Expand All @@ -37,6 +44,7 @@ func (c *ClientEndpoint) Start() {
}
defer listener.Close()
log.Infow("Client endpoint start up successful", "listen address", listener.Addr())

for {
// Accept client application connectin request
conn, err := listener.Accept()
Expand Down Expand Up @@ -77,6 +85,45 @@ func (c *ClientEndpoint) Start() {
}
}

func (c *ClientEndpoint) keepClientWorking() {
stream, err := session.OpenStreamSync(context.Background())
if err != nil {
log.Errorw("Failed to open stream to server endpoint.", "error", err.Error())
return
}
defer stream.Close()

timeTick := time.NewTicker(30 * time.Second)
for {
select {
case <-timeTick.C:
_, err = stream.Write([]byte("ping"))
if err != nil {
log.Errorw("Cannot read write for heartbeat stream.", "error", err.Error())
}
buf := make([]byte, len("pong"))
_, err = stream.Read(buf)
if err != nil {
log.Errorw("Cannot read data for heartbeat stream.", "error", err.Error())
}
if string(buf) != "pong" {
session, err = quic.DialAddr(c.ServerEndpointSocket, c.TlsConfig, &quic.Config{KeepAlive: true})
if err != nil {
log.Errorw("reconnect failed, Retry after 30s...")
break
}
stream, err = session.OpenStreamSync(context.Background())
if err != nil {
log.Errorw("Failed to open stream to server endpoint.", "error", err.Error())
return
}
log.Info("reconnect Success!")
}
default:
}
}
}

func handshake(ctx context.Context, stream *quic.Stream, hsh *tunnel.HandshakeHelper) (bool, *net.Conn) {
logger := log.FromContext(ctx)
logger.Info("Starting handshake with server endpoint")
Expand Down
30 changes: 29 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"net"
"strings"
"time"

"github.com/kungze/quic-tun/pkg/constants"
"github.com/kungze/quic-tun/pkg/log"
Expand Down Expand Up @@ -38,6 +39,9 @@ func (s *ServerEndpoint) Start() {
logger := log.WithValues(constants.ClientEndpointAddr, session.RemoteAddr().String())
logger.Info("A new client endpoint connect request accepted.")
go func() {
// First create the heartbeat steam.
keepClientWorking(session)

for {
// Wait client endpoint open a stream (A new steam means a new tunnel)
stream, err := session.AcceptStream(context.Background())
Expand All @@ -49,7 +53,6 @@ func (s *ServerEndpoint) Start() {
ctx := logger.WithContext(parent_ctx)
hsh := tunnel.NewHandshakeHelper(constants.AckMsgLength, handshake)
hsh.TokenParser = &s.TokenParser

tun := tunnel.NewTunnel(&stream, constants.ServerEndpoint)
tun.Hsh = &hsh
if !tun.HandShake(ctx) {
Expand All @@ -64,6 +67,31 @@ func (s *ServerEndpoint) Start() {
}
}

func keepClientWorking(session quic.Session) {
stream, err := session.AcceptStream(context.Background())
if err != nil {
log.Errorw("Cannot accept heartbeat stream.", "error", err.Error())
}
go func() {
timeTick := time.NewTicker(30 * time.Second)
for {
select {
case <-timeTick.C:
buf := make([]byte, len("ping"))
_, err = stream.Read(buf)
if err != nil {
log.Errorw("Cannot read data for heartbeat stream.", "error", err.Error())
}
_, err = stream.Write([]byte("pong"))
if err != nil {
log.Errorw("Cannot write data for heartbeat stream.", "error", err.Error())
}
default:
}
}
}()
}

func handshake(ctx context.Context, stream *quic.Stream, hsh *tunnel.HandshakeHelper) (bool, *net.Conn) {
logger := log.FromContext(ctx)
logger.Info("Starting handshake with client endpoint")
Expand Down

0 comments on commit a89f6b5

Please sign in to comment.