Skip to content

Commit 4b262c0

Browse files
committed
add CloseWithError
1 parent d8cf4e7 commit 4b262c0

File tree

3 files changed

+83
-33
lines changed

3 files changed

+83
-33
lines changed

errors.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ var (
6464

6565
// ErrSessionShutdown is used if there is a shutdown during
6666
// an operation
67-
ErrSessionShutdown = &Error{msg: "session shutdown"}
67+
ErrSessionShutdown = &GoAwayError{ErrorCode: goAwayNormal, Remote: false}
6868

6969
// ErrStreamsExhausted is returned if we have no more
7070
// stream ids to issue
@@ -87,7 +87,7 @@ var (
8787
ErrUnexpectedFlag = &Error{msg: "unexpected flag"}
8888

8989
// ErrRemoteGoAway is used when we get a go away from the other side
90-
ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"}
90+
ErrRemoteGoAway = &GoAwayError{Remote: true, ErrorCode: goAwayNormal}
9191

9292
// ErrStreamReset is sent if a stream is reset. This can happen
9393
// if the backlog is exceeded, or if there was a remote GoAway.

session.go

+42-27
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ type Session struct {
102102
// recvDoneCh is closed when recv() exits to avoid a race
103103
// between stream registration and stream shutdown
104104
recvDoneCh chan struct{}
105+
// recvErr is the error the receive loop ended with
106+
recvErr error
105107

106108
// sendDoneCh is closed when send() exits to avoid a race
107109
// between returning from a Stream.Write and exiting from the send loop
@@ -288,10 +290,18 @@ func (s *Session) AcceptStream() (*Stream, error) {
288290
// semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or
289291
// if there's unread data in the kernel receive buffer.
290292
func (s *Session) Close() error {
291-
return s.close(true, goAwayNormal)
293+
return s.close(ErrSessionShutdown, true, goAwayNormal)
292294
}
293295

294-
func (s *Session) close(sendGoAway bool, errCode uint32) error {
296+
// CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
297+
// The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
298+
// For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
299+
// receive buffer.
300+
func (s *Session) CloseWithError(errCode uint32) error {
301+
return s.close(&GoAwayError{Remote: false, ErrorCode: errCode}, true, errCode)
302+
}
303+
304+
func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) error {
295305
s.shutdownLock.Lock()
296306
defer s.shutdownLock.Unlock()
297307

@@ -300,23 +310,25 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
300310
}
301311
s.shutdown = true
302312
if s.shutdownErr == nil {
303-
s.shutdownErr = ErrSessionShutdown
313+
s.shutdownErr = shutdownErr
304314
}
305315
close(s.shutdownCh)
306316
s.stopKeepalive()
307317

308-
// wait for write loop to exit
309-
_ = s.conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)) // if SetWriteDeadline errored, any blocked writes will be unblocked
310-
<-s.sendDoneCh
311318
if sendGoAway {
319+
// wait for write loop to exit
320+
// We need to write the current frame completely before sending a goaway.
321+
// This will wait for at most s.config.ConnectionWriteTimeout
322+
<-s.sendDoneCh
312323
ga := s.goAway(errCode)
313324
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
314325
_, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here
315326
}
327+
s.conn.SetWriteDeadline(time.Time{})
316328
}
317329

318-
s.conn.SetWriteDeadline(time.Time{})
319330
s.conn.Close()
331+
<-s.sendDoneCh
320332
<-s.recvDoneCh
321333

322334
s.streamLock.Lock()
@@ -329,17 +341,6 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
329341
return nil
330342
}
331343

332-
// exitErr is used to handle an error that is causing the
333-
// session to terminate.
334-
func (s *Session) exitErr(err error) {
335-
s.shutdownLock.Lock()
336-
if s.shutdownErr == nil {
337-
s.shutdownErr = err
338-
}
339-
s.shutdownLock.Unlock()
340-
s.close(false, 0)
341-
}
342-
343344
// GoAway can be used to prevent accepting further
344345
// connections. It does not close the underlying conn.
345346
func (s *Session) GoAway() error {
@@ -468,7 +469,7 @@ func (s *Session) startKeepalive() {
468469

469470
if err != nil {
470471
s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
471-
s.exitErr(ErrKeepAliveTimeout)
472+
s.close(ErrKeepAliveTimeout, false, 0)
472473
}
473474
})
474475
}
@@ -533,7 +534,18 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
533534
// send is a long running goroutine that sends data
534535
func (s *Session) send() {
535536
if err := s.sendLoop(); err != nil {
536-
s.exitErr(err)
537+
// Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code
538+
// received in a GoAway frame received just before the TCP RST that closed the sendLoop
539+
//
540+
// Take the shutdownLock to avoid closing the connection concurrently with a Close call.
541+
s.shutdownLock.Lock()
542+
s.conn.Close()
543+
<-s.recvDoneCh
544+
if _, ok := s.recvErr.(*GoAwayError); ok {
545+
err = s.recvErr
546+
}
547+
s.shutdownLock.Unlock()
548+
s.close(err, false, 0)
537549
}
538550
}
539551

@@ -661,7 +673,7 @@ func (s *Session) sendLoop() (err error) {
661673
// recv is a long running goroutine that accepts new data
662674
func (s *Session) recv() {
663675
if err := s.recvLoop(); err != nil {
664-
s.exitErr(err)
676+
s.close(err, false, 0)
665677
}
666678
}
667679

@@ -683,7 +695,10 @@ func (s *Session) recvLoop() (err error) {
683695
err = fmt.Errorf("panic in yamux receive loop: %s", rerr)
684696
}
685697
}()
686-
defer close(s.recvDoneCh)
698+
defer func() {
699+
s.recvErr = err
700+
close(s.recvDoneCh)
701+
}()
687702
var hdr header
688703
for {
689704
// fmt.Printf("ReadFull from %#v\n", s.reader)
@@ -799,17 +814,17 @@ func (s *Session) handleGoAway(hdr header) error {
799814
switch code {
800815
case goAwayNormal:
801816
atomic.SwapInt32(&s.remoteGoAway, 1)
817+
// Don't close connection on normal go away. Let the existing streams
818+
// complete gracefully.
819+
return nil
802820
case goAwayProtoErr:
803821
s.logger.Printf("[ERR] yamux: received protocol error go away")
804-
return fmt.Errorf("yamux protocol error")
805822
case goAwayInternalErr:
806823
s.logger.Printf("[ERR] yamux: received internal error go away")
807-
return fmt.Errorf("remote yamux internal error")
808824
default:
809-
s.logger.Printf("[ERR] yamux: received unexpected go away")
810-
return fmt.Errorf("unexpected go away received")
825+
s.logger.Printf("[ERR] yamux: received go away with error code: %d", code)
811826
}
812-
return nil
827+
return &GoAwayError{Remote: true, ErrorCode: code}
813828
}
814829

815830
// incomingStream is used to create a new incoming stream

session_test.go

+39-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package yamux
33
import (
44
"bytes"
55
"context"
6+
"errors"
67
"fmt"
78
"io"
89
"math/rand"
@@ -39,6 +40,8 @@ type pipeConn struct {
3940
writeDeadline pipeDeadline
4041
writeBlocker chan struct{}
4142
closeCh chan struct{}
43+
closeOnce sync.Once
44+
closeErr error
4245
}
4346

4447
func (p *pipeConn) SetDeadline(t time.Time) error {
@@ -65,10 +68,12 @@ func (p *pipeConn) Write(b []byte) (int, error) {
6568
}
6669

6770
func (p *pipeConn) Close() error {
68-
p.writeDeadline.set(time.Time{})
69-
err := p.Conn.Close()
70-
close(p.closeCh)
71-
return err
71+
p.closeOnce.Do(func() {
72+
p.writeDeadline.set(time.Time{})
73+
p.closeErr = p.Conn.Close()
74+
close(p.closeCh)
75+
})
76+
return p.closeErr
7277
}
7378

7479
func (p *pipeConn) BlockWrites() {
@@ -650,6 +655,35 @@ func TestGoAway(t *testing.T) {
650655
default:
651656
t.Fatalf("err: %v", err)
652657
}
658+
time.Sleep(50 * time.Millisecond)
659+
}
660+
t.Fatalf("expected GoAway error")
661+
}
662+
663+
func TestCloseWithError(t *testing.T) {
664+
// This test is noisy.
665+
conf := testConf()
666+
conf.LogOutput = io.Discard
667+
668+
client, server := testClientServerConfig(conf)
669+
defer client.Close()
670+
defer server.Close()
671+
672+
if err := server.CloseWithError(42); err != nil {
673+
t.Fatalf("err: %v", err)
674+
}
675+
676+
for i := 0; i < 100; i++ {
677+
s, err := client.Open(context.Background())
678+
if err == nil {
679+
s.Close()
680+
time.Sleep(50 * time.Millisecond)
681+
continue
682+
}
683+
if !errors.Is(err, &GoAwayError{ErrorCode: 42, Remote: true}) {
684+
t.Fatalf("err: %v", err)
685+
}
686+
return
653687
}
654688
t.Fatalf("expected GoAway error")
655689
}
@@ -1048,6 +1082,7 @@ func TestKeepAlive_Timeout(t *testing.T) {
10481082
// Prevent the client from responding
10491083
clientConn := client.conn.(*pipeConn)
10501084
clientConn.BlockWrites()
1085+
defer clientConn.UnblockWrites()
10511086

10521087
select {
10531088
case err := <-errCh:

0 commit comments

Comments
 (0)