Skip to content

Commit b523bdd

Browse files
committed
review comments
1 parent ede18a5 commit b523bdd

File tree

4 files changed

+33
-36
lines changed

4 files changed

+33
-36
lines changed

const.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func (e *GoAwayError) Temporary() bool {
4545

4646
func (e *GoAwayError) Is(target error) bool {
4747
// to maintain compatibility with errors returned by previous versions
48-
if e.Remote && target == ErrRemoteGoAwayNormal {
48+
if e.Remote && target == ErrRemoteGoAway {
4949
return true
5050
} else if !e.Remote && target == ErrSessionShutdown {
5151
return true
@@ -114,8 +114,9 @@ var (
114114
// ErrUnexpectedFlag is set when we get an unexpected flag
115115
ErrUnexpectedFlag = &Error{msg: "unexpected flag"}
116116

117-
// ErrRemoteGoAwayNormal is used when we get a go away from the other side
118-
ErrRemoteGoAwayNormal = &GoAwayError{Remote: true, ErrorCode: goAwayNormal}
117+
// ErrRemoteGoAway is used when we get a go away from the other side with error code
118+
// goAwayNormal(0).
119+
ErrRemoteGoAway = &GoAwayError{Remote: true, ErrorCode: goAwayNormal}
119120

120121
// ErrStreamReset is sent if a stream is reset. This can happen
121122
// if the backlog is exceeded, or if there was a remote GoAway.

session.go

+9-9
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,6 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) {
205205
if s.IsClosed() {
206206
return nil, s.shutdownErr
207207
}
208-
if atomic.LoadInt32(&s.remoteGoAwayNormal) == 1 {
209-
return nil, ErrRemoteGoAwayNormal
210-
}
211208

212209
// Block if we have too many inflight SYNs
213210
select {
@@ -535,8 +532,14 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
535532
// send is a long running goroutine that sends data
536533
func (s *Session) send() {
537534
if err := s.sendLoop(); err != nil {
538-
// Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code
539-
// received in a GoAway frame received just before the TCP RST that closed the sendLoop
535+
// If we are shutting down because remote closed the connection, prefer the recvLoop error
536+
// over the sendLoop error. The receive loop might have error code received in a GoAway frame,
537+
// which was received just before the TCP RST that closed the sendLoop.
538+
//
539+
// If we are closing because of an write error, we use the error from the sendLoop and not the recvLoop.
540+
// We hold the shutdownLock, close the connection, and wait for the receive loop to finish and
541+
// use the sendLoop error. Holding the shutdownLock ensures that the recvLoop doesn't trigger connection close
542+
// but the sendLoop does.
540543
s.shutdownLock.Lock()
541544
if s.shutdownErr == nil {
542545
s.conn.Close()
@@ -815,10 +818,7 @@ func (s *Session) handleGoAway(hdr header) error {
815818
code := hdr.Length()
816819
switch code {
817820
case goAwayNormal:
818-
atomic.SwapInt32(&s.remoteGoAwayNormal, 1)
819-
// Don't close connection on normal go away. Let the existing streams
820-
// complete gracefully.
821-
return nil
821+
return ErrRemoteGoAway
822822
case goAwayProtoErr:
823823
s.logger.Printf("[ERR] yamux: received protocol error go away")
824824
case goAwayInternalErr:

session_test.go

+17-21
Original file line numberDiff line numberDiff line change
@@ -648,15 +648,16 @@ func TestGoAway(t *testing.T) {
648648

649649
for i := 0; i < 100; i++ {
650650
s, err := client.Open(context.Background())
651-
switch err {
652-
case nil:
651+
if err == nil {
653652
s.Close()
654-
case ErrRemoteGoAwayNormal:
653+
time.Sleep(50 * time.Millisecond)
654+
continue
655+
}
656+
if err != ErrRemoteGoAway {
657+
t.Fatalf("expected %s, got %s", ErrRemoteGoAway, err)
658+
} else {
655659
return
656-
default:
657-
t.Fatalf("err: %v", err)
658660
}
659-
time.Sleep(50 * time.Millisecond)
660661
}
661662
t.Fatalf("expected GoAway error")
662663
}
@@ -1578,7 +1579,7 @@ func TestStreamResetWithError(t *testing.T) {
15781579
defer server.Close()
15791580

15801581
wc := new(sync.WaitGroup)
1581-
wc.Add(2)
1582+
wc.Add(1)
15821583
go func() {
15831584
defer wc.Done()
15841585
stream, err := server.AcceptStream()
@@ -1589,7 +1590,7 @@ func TestStreamResetWithError(t *testing.T) {
15891590
se := &StreamError{}
15901591
_, err = io.ReadAll(stream)
15911592
if !errors.As(err, &se) {
1592-
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
1593+
t.Errorf("expected StreamError, got type:%T, err: %s", err, err)
15931594
return
15941595
}
15951596
expected := &StreamError{Remote: true, ErrorCode: 42}
@@ -1601,24 +1602,19 @@ func TestStreamResetWithError(t *testing.T) {
16011602
t.Error(err)
16021603
}
16031604

1604-
go func() {
1605-
defer wc.Done()
1606-
1607-
se := &StreamError{}
1608-
_, err := io.ReadAll(stream)
1609-
if !errors.As(err, &se) {
1610-
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
1611-
return
1612-
}
1613-
expected := &StreamError{Remote: false, ErrorCode: 42}
1614-
assert.Equal(t, se, expected)
1615-
}()
1616-
16171605
time.Sleep(1 * time.Second)
16181606
err = stream.ResetWithError(42)
16191607
if err != nil {
16201608
t.Fatal(err)
16211609
}
1610+
se := &StreamError{}
1611+
_, err = io.ReadAll(stream)
1612+
if !errors.As(err, &se) {
1613+
t.Errorf("expected StreamError, got type:%T, err: %s", err, err)
1614+
return
1615+
}
1616+
expected := &StreamError{Remote: false, ErrorCode: 42}
1617+
assert.Equal(t, se, expected)
16221618
wc.Wait()
16231619
}
16241620

stream.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ func (s *Stream) cleanup() {
395395

396396
// processFlags is used to update the state of the stream
397397
// based on set flags, if any. Lock must be held
398-
func (s *Stream) processFlags(flags uint16, hdr header) {
398+
func (s *Stream) processFlags(hdr header, flags uint16) {
399399
// Close the stream without holding the state lock
400400
var closeStream bool
401401
defer func() {
@@ -459,15 +459,15 @@ func (s *Stream) notifyWaiting() {
459459

460460
// incrSendWindow updates the size of our send window
461461
func (s *Stream) incrSendWindow(hdr header, flags uint16) {
462-
s.processFlags(flags, hdr)
462+
s.processFlags(hdr, flags)
463463
// Increase window, unblock a sender
464464
atomic.AddUint32(&s.sendWindow, hdr.Length())
465465
asyncNotify(s.sendNotifyCh)
466466
}
467467

468468
// readData is used to handle a data frame
469469
func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
470-
s.processFlags(flags, hdr)
470+
s.processFlags(hdr, flags)
471471

472472
// Check that our recv window is not exceeded
473473
length := hdr.Length()

0 commit comments

Comments
 (0)