Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Association and Stream closure #236

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ type Association struct {
storedCookieEcho *chunkCookieEcho

streams map[uint16]*Stream
streamsClosedErr error
acceptCh chan *Stream
readLoopCloseCh chan struct{}
awakeWriteLoopCh chan struct{}
Expand Down Expand Up @@ -513,6 +514,7 @@ func (a *Association) readLoop() {
for _, s := range a.streams {
a.unregisterStream(s, closeErr)
}
a.streamsClosedErr = closeErr
a.lock.Unlock()
close(a.acceptCh)
close(a.readLoopCloseCh)
Expand Down Expand Up @@ -552,9 +554,13 @@ func (a *Association) readLoop() {

func (a *Association) writeLoop() {
a.log.Debugf("[%s] writeLoop entered", a.name)
defer a.log.Debugf("[%s] writeLoop exited", a.name)
defer func() {
if err := a.close(); err != nil {
a.log.Warnf("[%s] failed to close association: %v", a.name, err)
}
a.log.Debugf("[%s] writeLoop exited", a.name)
}()

loop:
for {
rawPackets, ok := a.gatherOutbound()

Expand All @@ -565,28 +571,21 @@ loop:
a.log.Warnf("[%s] failed to write packets on netConn: %v", a.name, err)
}
a.log.Debugf("[%s] writeLoop ended", a.name)
break loop
return
}
atomic.AddUint64(&a.bytesSent, uint64(len(raw)))
}

if !ok {
if err := a.close(); err != nil {
a.log.Warnf("[%s] failed to close association: %v", a.name, err)
}

return
}

select {
case <-a.awakeWriteLoopCh:
case <-a.closeWriteLoopCh:
break loop
return
}
}

a.setState(closed)
a.closeAllTimers()
}

func (a *Association) awakeWriteLoop() {
Expand Down Expand Up @@ -1349,6 +1348,10 @@ func (a *Association) OpenStream(streamIdentifier uint16, defaultPayloadType Pay
a.lock.Lock()
defer a.lock.Unlock()

if a.streamsClosedErr != nil {
return nil, a.streamsClosedErr
}

return a.getOrCreateStream(streamIdentifier, false, defaultPayloadType), nil
}

Expand All @@ -1363,6 +1366,11 @@ func (a *Association) AcceptStream() (*Stream, error) {

// createStream creates a stream. The caller should hold the lock and check no stream exists for this id.
func (a *Association) createStream(streamIdentifier uint16, accept bool) *Stream {
if a.streamsClosedErr != nil {
a.log.Debugf("[%s] dropped a new stream (streamsClosedErr: %s)", a.name, a.streamsClosedErr)
return nil
}

s := &Stream{
association: a,
streamIdentifier: streamIdentifier,
Expand Down
71 changes: 71 additions & 0 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2168,6 +2168,7 @@ func TestAssocReset(t *testing.T) {
_, _, err = s0.ReadSCTP(buf)
assert.Equal(t, io.EOF, err, "should be EOF")
doneCh <- err
return
}
}()

Expand Down Expand Up @@ -2278,6 +2279,11 @@ func (c *fakeEchoConn) Write(b []byte) (int, error) {
func (c *fakeEchoConn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
select {
case <-c.closed:
return c.errClose
default:
}
close(c.echo)
close(c.closed)
return c.errClose
Expand Down Expand Up @@ -2836,4 +2842,69 @@ func TestAssociation_Abort(t *testing.T) {
i, err = s21.Read(buf)
assert.Equal(t, i, 0, "expected no data read")
assert.Error(t, err, "User Initiated Abort: 1234", "expected abort reason")

// Ensure a1 has closed down as well (avoid goroutine leak).
select {
case <-a1.readLoopCloseCh:
case <-time.After(1 * time.Second):
assert.Fail(t, "timed out waiting for a1 read loop to close")
}

time.Sleep(time.Millisecond) // give readLoop a ms to completely exit.
}

func TestAssociation_OpenStreamAfterCloseMustNotHang(t *testing.T) {
runtime.GC()
n0 := runtime.NumGoroutine()

defer func() {
runtime.GC()
assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked")
}()

a1, a2 := createAssocs(t)

s11, err := a1.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)

startOpenStream := make(chan struct{})
go func() {
_ = a2.close() // trigger close of read loop.
close(startOpenStream)
}()

done := make(chan struct{})
go func() {
defer close(done)

<-startOpenStream
s21, err := a2.OpenStream(1, PayloadTypeWebRTCString)
if err == nil {
// If stream opened, ensure ReadSCTP doesn't hang.
_, _, err = s21.ReadSCTP(make([]byte, 1))
assert.Error(t, err, "read did not exit with error")
}
}()

timeout := time.After(2 * time.Second)

select {
case <-done:
case <-timeout:
assert.Fail(t, "timed out waiting for a2.OpenStream test goroutine")
}

_ = s11.Close()
select {
case <-a1.readLoopCloseCh:
case <-timeout:
assert.Fail(t, "timed out waiting for a1 read loop to close")
}
select {
case <-a2.readLoopCloseCh:
case <-timeout:
assert.Fail(t, "timed out waiting for a2 read loop to close")
}

time.Sleep(time.Millisecond) // give readLoop a ms to completely exit.
}