@@ -102,6 +102,8 @@ type Session struct {
102
102
// recvDoneCh is closed when recv() exits to avoid a race
103
103
// between stream registration and stream shutdown
104
104
recvDoneCh chan struct {}
105
+ // recvErr is the error the receive loop ended with
106
+ recvErr error
105
107
106
108
// sendDoneCh is closed when send() exits to avoid a race
107
109
// between returning from a Stream.Write and exiting from the send loop
@@ -288,10 +290,18 @@ func (s *Session) AcceptStream() (*Stream, error) {
288
290
// semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or
289
291
// if there's unread data in the kernel receive buffer.
290
292
func (s * Session ) Close () error {
291
- return s .close ( true , goAwayNormal )
293
+ return s .closeWithGoAway ( goAwayNormal )
292
294
}
293
295
294
- func (s * Session ) close (sendGoAway bool , errCode uint32 ) error {
296
+ // CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
297
+ // The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
298
+ // For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
299
+ // receive buffer.
300
+ func (s * Session ) CloseWithError (errCode uint32 ) error {
301
+ return s .closeWithGoAway (errCode )
302
+ }
303
+
304
+ func (s * Session ) closeWithGoAway (errCode uint32 ) error {
295
305
s .shutdownLock .Lock ()
296
306
defer s .shutdownLock .Unlock ()
297
307
@@ -300,22 +310,25 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
300
310
}
301
311
s .shutdown = true
302
312
if s .shutdownErr == nil {
303
- s .shutdownErr = ErrSessionShutdown
313
+ if errCode == goAwayNormal {
314
+ s .shutdownErr = ErrSessionShutdown
315
+ } else {
316
+ s .shutdownErr = & GoAwayError {Remote : false , ErrorCode : errCode }
317
+ }
304
318
}
305
319
close (s .shutdownCh )
306
320
s .stopKeepalive ()
307
321
308
322
// wait for write loop to exit
309
- _ = s .conn .SetWriteDeadline (time .Now ().Add (- 1 * time .Hour )) // if SetWriteDeadline errored, any blocked writes will be unblocked
323
+ // We need to complete writing the current frame before sending a goaway
324
+ // This will wait for at most s.config.ConnectionWriteTimeout
310
325
<- s .sendDoneCh
311
- if sendGoAway {
312
- ga := s .goAway (errCode )
313
- if err := s .conn .SetWriteDeadline (time .Now ().Add (goAwayWaitTime )); err == nil {
314
- _ , _ = s .conn .Write (ga [:]) // there's nothing we can do on error here
315
- }
326
+ ga := s .goAway (errCode )
327
+ if err := s .conn .SetWriteDeadline (time .Now ().Add (goAwayWaitTime )); err == nil {
328
+ _ , _ = s .conn .Write (ga [:]) // there's nothing we can do on error here
316
329
}
317
-
318
330
s .conn .SetWriteDeadline (time.Time {})
331
+
319
332
s .conn .Close ()
320
333
<- s .recvDoneCh
321
334
@@ -329,15 +342,30 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
329
342
return nil
330
343
}
331
344
332
- // exitErr is used to handle an error that is causing the
333
- // session to terminate.
334
- func (s * Session ) exitErr (err error ) {
345
+ func (s * Session ) closeWithoutGoAway (err error ) error {
335
346
s .shutdownLock .Lock ()
347
+ defer s .shutdownLock .Unlock ()
348
+ if s .shutdown {
349
+ return nil
350
+ }
351
+ s .shutdown = true
336
352
if s .shutdownErr == nil {
337
353
s .shutdownErr = err
338
354
}
339
- s .shutdownLock .Unlock ()
340
- s .close (false , 0 )
355
+ close (s .shutdownCh )
356
+ s .conn .Close ()
357
+ <- s .sendDoneCh
358
+ <- s .recvDoneCh
359
+ s .stopKeepalive ()
360
+
361
+ s .streamLock .Lock ()
362
+ defer s .streamLock .Unlock ()
363
+ for id , stream := range s .streams {
364
+ stream .forceClose ()
365
+ delete (s .streams , id )
366
+ stream .memorySpan .Done ()
367
+ }
368
+ return nil
341
369
}
342
370
343
371
// GoAway can be used to prevent accepting further
@@ -468,7 +496,7 @@ func (s *Session) startKeepalive() {
468
496
469
497
if err != nil {
470
498
s .logger .Printf ("[ERR] yamux: keepalive failed: %v" , err )
471
- s .exitErr (ErrKeepAliveTimeout )
499
+ s .closeWithoutGoAway (ErrKeepAliveTimeout )
472
500
}
473
501
})
474
502
}
@@ -533,7 +561,20 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
533
561
// send is a long running goroutine that sends data
534
562
func (s * Session ) send () {
535
563
if err := s .sendLoop (); err != nil {
536
- s .exitErr (err )
564
+ // Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code
565
+ // received in a GoAway frame received just before the TCP RST that closed the sendLoop
566
+ s .shutdownLock .Lock ()
567
+ if s .shutdownErr == nil {
568
+ s .conn .Close ()
569
+ <- s .recvDoneCh
570
+ if _ , ok := s .recvErr .(* GoAwayError ); ok {
571
+ s .shutdownErr = s .recvErr
572
+ } else {
573
+ s .shutdownErr = err
574
+ }
575
+ }
576
+ s .shutdownLock .Unlock ()
577
+ s .closeWithoutGoAway (err )
537
578
}
538
579
}
539
580
@@ -661,7 +702,7 @@ func (s *Session) sendLoop() (err error) {
661
702
// recv is a long running goroutine that accepts new data
662
703
func (s * Session ) recv () {
663
704
if err := s .recvLoop (); err != nil {
664
- s .exitErr (err )
705
+ s .closeWithoutGoAway (err )
665
706
}
666
707
}
667
708
@@ -683,7 +724,10 @@ func (s *Session) recvLoop() (err error) {
683
724
err = fmt .Errorf ("panic in yamux receive loop: %s" , rerr )
684
725
}
685
726
}()
686
- defer close (s .recvDoneCh )
727
+ defer func () {
728
+ s .recvErr = err
729
+ close (s .recvDoneCh )
730
+ }()
687
731
var hdr header
688
732
for {
689
733
// fmt.Printf("ReadFull from %#v\n", s.reader)
@@ -799,17 +843,17 @@ func (s *Session) handleGoAway(hdr header) error {
799
843
switch code {
800
844
case goAwayNormal :
801
845
atomic .SwapInt32 (& s .remoteGoAway , 1 )
846
+ // Don't close connection on normal go away. Let the existing streams
847
+ // complete gracefully.
848
+ return nil
802
849
case goAwayProtoErr :
803
850
s .logger .Printf ("[ERR] yamux: received protocol error go away" )
804
- return fmt .Errorf ("yamux protocol error" )
805
851
case goAwayInternalErr :
806
852
s .logger .Printf ("[ERR] yamux: received internal error go away" )
807
- return fmt .Errorf ("remote yamux internal error" )
808
853
default :
809
- s .logger .Printf ("[ERR] yamux: received unexpected go away" )
810
- return fmt .Errorf ("unexpected go away received" )
854
+ s .logger .Printf ("[ERR] yamux: received go away with error code: %d" , code )
811
855
}
812
- return nil
856
+ return & GoAwayError { Remote : true , ErrorCode : code }
813
857
}
814
858
815
859
// incomingStream is used to create a new incoming stream
0 commit comments