@@ -102,6 +102,8 @@ type Session struct {
102
102
// recvDoneCh is closed when recv() exits to avoid a race
103
103
// between stream registration and stream shutdown
104
104
recvDoneCh chan struct {}
105
+ // recvErr is the error the receive loop ended with
106
+ recvErr error
105
107
106
108
// sendDoneCh is closed when send() exits to avoid a race
107
109
// between returning from a Stream.Write and exiting from the send loop
@@ -288,10 +290,18 @@ func (s *Session) AcceptStream() (*Stream, error) {
288
290
// semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or
289
291
// if there's unread data in the kernel receive buffer.
290
292
func (s * Session ) Close () error {
291
- return s .close (true , goAwayNormal )
293
+ return s .close (ErrSessionShutdown , true , goAwayNormal )
292
294
}
293
295
294
- func (s * Session ) close (sendGoAway bool , errCode uint32 ) error {
296
+ // CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
297
+ // The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
298
+ // For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
299
+ // receive buffer.
300
+ func (s * Session ) CloseWithError (errCode uint32 ) error {
301
+ return s .close (& GoAwayError {Remote : false , ErrorCode : errCode }, true , errCode )
302
+ }
303
+
304
+ func (s * Session ) close (shutdownErr error , sendGoAway bool , errCode uint32 ) error {
295
305
s .shutdownLock .Lock ()
296
306
defer s .shutdownLock .Unlock ()
297
307
@@ -300,23 +310,25 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
300
310
}
301
311
s .shutdown = true
302
312
if s .shutdownErr == nil {
303
- s .shutdownErr = ErrSessionShutdown
313
+ s .shutdownErr = shutdownErr
304
314
}
305
315
close (s .shutdownCh )
306
316
s .stopKeepalive ()
307
317
308
- // wait for write loop to exit
309
- _ = s .conn .SetWriteDeadline (time .Now ().Add (- 1 * time .Hour )) // if SetWriteDeadline errored, any blocked writes will be unblocked
310
- <- s .sendDoneCh
311
318
if sendGoAway {
319
+ // wait for write loop to exit
320
+ // We need to write the current frame completely before sending a goaway.
321
+ // This will wait for at most s.config.ConnectionWriteTimeout
322
+ <- s .sendDoneCh
312
323
ga := s .goAway (errCode )
313
324
if err := s .conn .SetWriteDeadline (time .Now ().Add (goAwayWaitTime )); err == nil {
314
325
_ , _ = s .conn .Write (ga [:]) // there's nothing we can do on error here
315
326
}
327
+ s .conn .SetWriteDeadline (time.Time {})
316
328
}
317
329
318
- s .conn .SetWriteDeadline (time.Time {})
319
330
s .conn .Close ()
331
+ <- s .sendDoneCh
320
332
<- s .recvDoneCh
321
333
322
334
s .streamLock .Lock ()
@@ -329,17 +341,6 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
329
341
return nil
330
342
}
331
343
332
- // exitErr is used to handle an error that is causing the
333
- // session to terminate.
334
- func (s * Session ) exitErr (err error ) {
335
- s .shutdownLock .Lock ()
336
- if s .shutdownErr == nil {
337
- s .shutdownErr = err
338
- }
339
- s .shutdownLock .Unlock ()
340
- s .close (false , 0 )
341
- }
342
-
343
344
// GoAway can be used to prevent accepting further
344
345
// connections. It does not close the underlying conn.
345
346
func (s * Session ) GoAway () error {
@@ -468,7 +469,7 @@ func (s *Session) startKeepalive() {
468
469
469
470
if err != nil {
470
471
s .logger .Printf ("[ERR] yamux: keepalive failed: %v" , err )
471
- s .exitErr (ErrKeepAliveTimeout )
472
+ s .close (ErrKeepAliveTimeout , false , 0 )
472
473
}
473
474
})
474
475
}
@@ -533,7 +534,18 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
533
534
// send is a long running goroutine that sends data
534
535
func (s * Session ) send () {
535
536
if err := s .sendLoop (); err != nil {
536
- s .exitErr (err )
537
+ // Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code
538
+ // received in a GoAway frame received just before the TCP RST that closed the sendLoop
539
+ //
540
+ // Take the shutdownLock to avoid closing the connection concurrently with a Close call.
541
+ s .shutdownLock .Lock ()
542
+ s .conn .Close ()
543
+ <- s .recvDoneCh
544
+ if _ , ok := s .recvErr .(* GoAwayError ); ok {
545
+ err = s .recvErr
546
+ }
547
+ s .shutdownLock .Unlock ()
548
+ s .close (err , false , 0 )
537
549
}
538
550
}
539
551
@@ -661,7 +673,7 @@ func (s *Session) sendLoop() (err error) {
661
673
// recv is a long running goroutine that accepts new data
662
674
func (s * Session ) recv () {
663
675
if err := s .recvLoop (); err != nil {
664
- s .exitErr (err )
676
+ s .close (err , false , 0 )
665
677
}
666
678
}
667
679
@@ -683,7 +695,10 @@ func (s *Session) recvLoop() (err error) {
683
695
err = fmt .Errorf ("panic in yamux receive loop: %s" , rerr )
684
696
}
685
697
}()
686
- defer close (s .recvDoneCh )
698
+ defer func () {
699
+ s .recvErr = err
700
+ close (s .recvDoneCh )
701
+ }()
687
702
var hdr header
688
703
for {
689
704
// fmt.Printf("ReadFull from %#v\n", s.reader)
@@ -799,17 +814,17 @@ func (s *Session) handleGoAway(hdr header) error {
799
814
switch code {
800
815
case goAwayNormal :
801
816
atomic .SwapInt32 (& s .remoteGoAway , 1 )
817
+ // Don't close connection on normal go away. Let the existing streams
818
+ // complete gracefully.
819
+ return nil
802
820
case goAwayProtoErr :
803
821
s .logger .Printf ("[ERR] yamux: received protocol error go away" )
804
- return fmt .Errorf ("yamux protocol error" )
805
822
case goAwayInternalErr :
806
823
s .logger .Printf ("[ERR] yamux: received internal error go away" )
807
- return fmt .Errorf ("remote yamux internal error" )
808
824
default :
809
- s .logger .Printf ("[ERR] yamux: received unexpected go away" )
810
- return fmt .Errorf ("unexpected go away received" )
825
+ s .logger .Printf ("[ERR] yamux: received go away with error code: %d" , code )
811
826
}
812
- return nil
827
+ return & GoAwayError { Remote : true , ErrorCode : code }
813
828
}
814
829
815
830
// incomingStream is used to create a new incoming stream
0 commit comments