Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add receive chunk tracker for better received chunk handling #319

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ type Association struct {
myMaxNumInboundStreams uint16
myMaxNumOutboundStreams uint16
myCookie *paramStateCookie
payloadQueue *payloadQueue
receivedChunkTracker *receivedChunkTracker
inflightQueue *payloadQueue
pendingQueue *pendingQueue
controlQueue *controlQueue
Expand Down Expand Up @@ -329,7 +329,7 @@ func createAssociation(config Config) *Association {
myMaxNumOutboundStreams: math.MaxUint16,
myMaxNumInboundStreams: math.MaxUint16,

payloadQueue: newPayloadQueue(),
receivedChunkTracker: newReceivedChunkTracker(),
inflightQueue: newPayloadQueue(),
pendingQueue: newPendingQueue(),
controlQueue: newControlQueue(),
Expand Down Expand Up @@ -1406,7 +1406,7 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet {
a.name, d.tsn, d.immediateSack, len(d.userData))
a.stats.incDATAs()

canPush := a.payloadQueue.canPush(d, a.peerLastTSN, a.getMaxTSNOffset())
canPush := a.receivedChunkTracker.canPush(d, a.peerLastTSN, a.getMaxTSNOffset())
if canPush {
s := a.getOrCreateStream(d.streamIdentifier, true, PayloadTypeUnknown)
if s == nil {
Expand All @@ -1418,14 +1418,14 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet {

if a.getMyReceiverWindowCredit() > 0 {
// Pass the new chunk to stream level as soon as it arrives
a.payloadQueue.push(d, a.peerLastTSN)
a.receivedChunkTracker.push(d.tsn, a.peerLastTSN)
s.handleData(d)
} else {
// Receive buffer is full
lastTSN, ok := a.payloadQueue.getLastTSNReceived()
lastTSN, ok := a.receivedChunkTracker.getLastTSNReceived()
if ok && sna32LT(d.tsn, lastTSN) {
a.log.Debugf("[%s] receive buffer full, but accepted as this is a missing chunk with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber)
a.payloadQueue.push(d, a.peerLastTSN)
a.receivedChunkTracker.push(d.tsn, a.peerLastTSN)
s.handleData(d)
} else {
a.log.Debugf("[%s] receive buffer full. dropping DATA with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber)
Expand All @@ -1449,7 +1449,7 @@ func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool)
// Meaning, if peerLastTSN+1 points to a chunk that is received,
// advance peerLastTSN until peerLastTSN+1 points to unreceived chunk.
for {
if _, popOk := a.payloadQueue.pop(a.peerLastTSN + 1); !popOk {
if popOk := a.receivedChunkTracker.pop(a.peerLastTSN + 1); !popOk {
break
}
a.peerLastTSN++
Expand All @@ -1463,9 +1463,9 @@ func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool)
}
}

hasPacketLoss := (a.payloadQueue.size() > 0)
hasPacketLoss := (a.receivedChunkTracker.size() > 0)
if hasPacketLoss {
a.log.Tracef("[%s] packetloss: %s", a.name, a.payloadQueue.getGapAckBlocksString(a.peerLastTSN))
a.log.Tracef("[%s] packetloss: %s", a.name, a.receivedChunkTracker.getGapAckBlocksString(a.peerLastTSN))
}

if (a.ackState != ackStateImmediate && !sackImmediately && !hasPacketLoss && a.ackMode == ackModeNormal) || a.ackMode == ackModeAlwaysDelay {
Expand Down Expand Up @@ -2084,7 +2084,7 @@ func (a *Association) handleForwardTSN(c *chunkForwardTSN) []*packet {

// Advance peerLastTSN
for sna32LT(a.peerLastTSN, c.newCumulativeTSN) {
a.payloadQueue.pop(a.peerLastTSN + 1) // may not exist
a.receivedChunkTracker.pop(a.peerLastTSN + 1) // may not exist
a.peerLastTSN++
}

Expand Down Expand Up @@ -2427,8 +2427,8 @@ func (a *Association) createSelectiveAckChunk() *chunkSelectiveAck {
sack := &chunkSelectiveAck{}
sack.cumulativeTSNAck = a.peerLastTSN
sack.advertisedReceiverWindowCredit = a.getMyReceiverWindowCredit()
sack.duplicateTSN = a.payloadQueue.popDuplicates()
sack.gapAckBlocks = a.payloadQueue.getGapAckBlocks(a.peerLastTSN)
sack.duplicateTSN = a.receivedChunkTracker.popDuplicates()
sack.gapAckBlocks = a.receivedChunkTracker.getGapAckBlocks(a.peerLastTSN)
return sack
}

Expand Down
108 changes: 108 additions & 0 deletions association_bench_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package sctp

import (
"io"
"net"
"testing"

"github.com/pion/logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type netConnWrapper struct {
net.PacketConn
remoteAddr net.Addr
}

func (c *netConnWrapper) Read(b []byte) (int, error) {
n, _, err := c.PacketConn.ReadFrom(b)
return n, err
}

func (c *netConnWrapper) RemoteAddr() net.Addr {
return c.remoteAddr
}

func (c *netConnWrapper) Write(b []byte) (n int, err error) {
return c.PacketConn.WriteTo(b, c.remoteAddr)
}

var _ net.Conn = &netConnWrapper{}

func newNetConnPair(p1 net.PacketConn, p2 net.PacketConn) (net.Conn, net.Conn) {
return &netConnWrapper{
PacketConn: p1,
remoteAddr: p2.LocalAddr(),
},
&netConnWrapper{
PacketConn: p2,
remoteAddr: p1.LocalAddr(),
}
}

func BenchmarkSCTPThroughput(b *testing.B) {
b.ReportAllocs()
p1, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(b, err)
p2, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(b, err)

c1, c2 := newNetConnPair(p1, p2)
var server *Association
done := make(chan bool)
go func() {
var err error
server, err = Server(Config{
Name: "server",
NetConn: c1,
MaxReceiveBufferSize: 1024 * 1024,
LoggerFactory: &logging.DefaultLoggerFactory{},
})
require.NoError(b, err)
done <- true
}()

var client *Association
go func() {
var err error
client, err = Client(Config{
Name: "client",
NetConn: c2,
MaxReceiveBufferSize: 1024 * 1024,
LoggerFactory: &logging.DefaultLoggerFactory{},
})
require.NoError(b, err)
done <- true
}()
<-done
<-done
serverBuf := make([]byte, 16*(1<<10))
clientBuf := make([]byte, 16*(1<<10))
for i := 0; i < b.N; i++ {
s, err := client.OpenStream(uint16(i), PayloadTypeWebRTCBinary)
require.NoError(b, err)
go func() {
s, err := server.AcceptStream()
assert.NoError(b, err)
for {
_, err := s.Read(serverBuf)
if err != nil {
if err == io.EOF {
s.Close()

Check failure on line 92 in association_bench_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

Error return value of `s.Close` is not checked (errcheck)
break
} else {

Check warning on line 94 in association_bench_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

superfluous-else: if block ends with a break statement, so drop this else and outdent its block (revive)
b.Error("invalid err", err)
}
}
}
done <- true
}()
for i := 0; i < 1000; i++ {
_, err := s.Write(clientBuf)
require.NoError(b, err)
}
s.Close()

Check failure on line 105 in association_bench_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

Error return value of `s.Close` is not checked (errcheck)
<-done
}
}
18 changes: 2 additions & 16 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1311,14 +1311,7 @@ func TestHandleForwardTSN(t *testing.T) {
prevTSN := a.peerLastTSN

// this chunk is blocked by the missing chunk at tsn=1
a.payloadQueue.push(&chunkPayloadData{
beginningFragment: true,
endingFragment: true,
tsn: a.peerLastTSN + 2,
streamIdentifier: 0,
streamSequenceNumber: 1,
userData: []byte("ABC"),
}, a.peerLastTSN)
a.receivedChunkTracker.push(a.peerLastTSN+2, a.peerLastTSN)

fwdtsn := &chunkForwardTSN{
newCumulativeTSN: a.peerLastTSN + 1,
Expand Down Expand Up @@ -1348,14 +1341,7 @@ func TestHandleForwardTSN(t *testing.T) {
prevTSN := a.peerLastTSN

// this chunk is blocked by the missing chunk at tsn=1
a.payloadQueue.push(&chunkPayloadData{
beginningFragment: true,
endingFragment: true,
tsn: a.peerLastTSN + 3,
streamIdentifier: 0,
streamSequenceNumber: 1,
userData: []byte("ABC"),
}, a.peerLastTSN)
a.receivedChunkTracker.push(a.peerLastTSN+3, a.peerLastTSN)

fwdtsn := &chunkForwardTSN{
newCumulativeTSN: a.peerLastTSN + 1,
Expand Down
8 changes: 0 additions & 8 deletions payload_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,6 @@
})
}

func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32, maxTSNOffset uint32) bool {
_, ok := q.chunkMap[p.tsn]
if ok || sna32LTE(p.tsn, cumulativeTSN) || sna32GTE(p.tsn, cumulativeTSN+maxTSNOffset) {
return false
}
return true
}

func (q *payloadQueue) pushNoCheck(p *chunkPayloadData) {
q.chunkMap[p.tsn] = p
q.nBytes += len(p.userData)
Expand All @@ -53,7 +45,7 @@
// push pushes a payload data. If the payload data is already in our queue or
// older than our cumulativeTSN marker, it will be recored as duplications,
// which can later be retrieved using popDuplicates.
func (q *payloadQueue) push(p *chunkPayloadData, cumulativeTSN uint32) bool {

Check failure on line 48 in payload_queue.go

View workflow job for this annotation

GitHub Actions / lint / Go

`(*payloadQueue).push` - `cumulativeTSN` always receives `0` (unparam)
_, ok := q.chunkMap[p.tsn]
if ok || sna32LTE(p.tsn, cumulativeTSN) {
// Found the packet, log in dups
Expand Down Expand Up @@ -90,7 +82,7 @@
}

// popDuplicates returns an array of TSN values that were found duplicate.
func (q *payloadQueue) popDuplicates() []uint32 {

Check failure on line 85 in payload_queue.go

View workflow job for this annotation

GitHub Actions / lint / Go

func `(*payloadQueue).popDuplicates` is unused (unused)
dups := q.dupTSN
q.dupTSN = []uint32{}
return dups
Expand Down Expand Up @@ -132,7 +124,7 @@
return gapAckBlocks
}

func (q *payloadQueue) getGapAckBlocksString(cumulativeTSN uint32) string {

Check failure on line 127 in payload_queue.go

View workflow job for this annotation

GitHub Actions / lint / Go

func `(*payloadQueue).getGapAckBlocksString` is unused (unused)
gapAckBlocks := q.getGapAckBlocks(cumulativeTSN)
str := fmt.Sprintf("cumTSN=%d", cumulativeTSN)
for _, b := range gapAckBlocks {
Expand Down
Loading
Loading