Skip to content

Commit

Permalink
Fix race between Conn.Close and Conn.Handshake
Browse files Browse the repository at this point in the history
  • Loading branch information
Danielius1922 authored and Sean-Der committed Aug 22, 2024
1 parent 032d60c commit 1a02350
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 10 deletions.
40 changes: 30 additions & 10 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ type Conn struct {

handshakeCompletedSuccessfully atomic.Value
handshakeMutex sync.Mutex
handshakeDone chan struct{}

encryptedPackets []addrPkt

connectionClosedByUser bool
closeLock sync.Mutex
closed *closer.Closer
handshakeLoopsFinished sync.WaitGroup

readDeadline *deadline.Deadline
writeDeadline *deadline.Deadline
Expand Down Expand Up @@ -256,6 +256,12 @@ func (c *Conn) HandshakeContext(ctx context.Context) error {
return nil
}

handshakeDone := make(chan struct{})
defer close(handshakeDone)
c.closeLock.Lock()
c.handshakeDone = handshakeDone
c.closeLock.Unlock()

// rfc5246#section-7.4.3
// In addition, the hash and signature algorithms MUST be compatible
// with the key in the server's end-entity certificate.
Expand Down Expand Up @@ -405,7 +411,12 @@ func (c *Conn) Write(p []byte) (int, error) {
// Close closes the connection.
func (c *Conn) Close() error {
err := c.close(true) //nolint:contextcheck
c.handshakeLoopsFinished.Wait()
c.closeLock.Lock()
handshakeDone := c.handshakeDone
c.closeLock.Unlock()
if handshakeDone != nil {
<-handshakeDone
}
return err
}

Expand Down Expand Up @@ -1026,7 +1037,6 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh

done := make(chan struct{})
ctxRead, cancelRead := context.WithCancel(context.Background())
c.cancelHandshakeReader = cancelRead
cfg.onFlightState = func(_ flightVal, s handshakeState) {
if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
c.setHandshakeCompletedSuccessfully()
Expand All @@ -1035,16 +1045,21 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
}

ctxHs, cancel := context.WithCancel(context.Background())

c.closeLock.Lock()
c.cancelHandshaker = cancel
c.cancelHandshakeReader = cancelRead
c.closeLock.Unlock()

firstErr := make(chan error, 1)

c.handshakeLoopsFinished.Add(2)
var handshakeLoopsFinished sync.WaitGroup
handshakeLoopsFinished.Add(2)

// Handshake routine should be live until close.
// The other party may request retransmission of the last flight to cope with packet drop.
go func() {
defer c.handshakeLoopsFinished.Done()
defer handshakeLoopsFinished.Done()
err := c.fsm.Run(ctxHs, c, initialState)
if !errors.Is(err, context.Canceled) {
select {
Expand All @@ -1064,7 +1079,7 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
// Force stop handshaker when the underlying connection is closed.
cancel()
}()
defer c.handshakeLoopsFinished.Done()
defer handshakeLoopsFinished.Done()
for {
if err := c.readAndBuffer(ctxRead); err != nil {
var e *alertError
Expand Down Expand Up @@ -1123,12 +1138,12 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
case err := <-firstErr:
cancelRead()
cancel()
c.handshakeLoopsFinished.Wait()
handshakeLoopsFinished.Wait()
return c.translateHandshakeCtxError(err)
case <-ctx.Done():
cancelRead()
cancel()
c.handshakeLoopsFinished.Wait()
handshakeLoopsFinished.Wait()
return c.translateHandshakeCtxError(ctx.Err())
case <-done:
return nil
Expand All @@ -1146,8 +1161,13 @@ func (c *Conn) translateHandshakeCtxError(err error) error {
}

func (c *Conn) close(byUser bool) error {
c.cancelHandshaker()
c.cancelHandshakeReader()
c.closeLock.Lock()
cancelHandshaker := c.cancelHandshaker
cancelHandshakeReader := c.cancelHandshakeReader
c.closeLock.Unlock()

cancelHandshaker()
cancelHandshakeReader()

if c.isHandshakeCompletedSuccessfully() && byUser {
// Discard error from notify() to return non-error on the first user call of Close()
Expand Down
51 changes: 51 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3662,3 +3662,54 @@ func TestMultiHandshake(t *testing.T) {
t.Fatal(err)
}
}

func TestCloseDuringHandshake(t *testing.T) {
defer test.CheckRoutines(t)()
defer test.TimeOut(time.Second * 10).Stop()

serverCert, err := selfsign.GenerateSelfSigned()
if err != nil {
t.Fatal(err)
}

for i := 0; i < 100; i++ {
_, cb := dpipe.Pipe()
server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{
Certificates: []tls.Certificate{serverCert},
})
if err != nil {
t.Fatal(err)
}

waitChan := make(chan struct{})
go func() {
close(waitChan)
_ = server.Handshake()
}()

<-waitChan
if err = server.Close(); err != nil {
t.Fatal(err)
}
}
}

func TestCloseWithoutHandshake(t *testing.T) {
defer test.CheckRoutines(t)()
defer test.TimeOut(time.Second * 10).Stop()

serverCert, err := selfsign.GenerateSelfSigned()
if err != nil {
t.Fatal(err)
}
_, cb := dpipe.Pipe()
server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{
Certificates: []tls.Certificate{serverCert},
})
if err != nil {
t.Fatal(err)
}
if err = server.Close(); err != nil {
t.Fatal(err)
}
}

0 comments on commit 1a02350

Please sign in to comment.