From dd87758c380cf8ab8650e2c77cd0b4505372f08a Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 1 Apr 2024 15:52:31 +0530 Subject: [PATCH 1/3] Add receive chunk tracker for better received chunk handling --- association.go | 24 ++--- association_test.go | 18 +--- payload_queue.go | 8 -- received_chunk_tracker.go | 158 +++++++++++++++++++++++++++++++++ received_chunk_tracker_test.go | 125 ++++++++++++++++++++++++++ 5 files changed, 297 insertions(+), 36 deletions(-) create mode 100644 received_chunk_tracker.go create mode 100644 received_chunk_tracker_test.go diff --git a/association.go b/association.go index d5722d3c..8f0c59eb 100644 --- a/association.go +++ b/association.go @@ -186,7 +186,7 @@ type Association struct { myMaxNumInboundStreams uint16 myMaxNumOutboundStreams uint16 myCookie *paramStateCookie - payloadQueue *payloadQueue + receivedChunkTracker *receivedChunkTracker inflightQueue *payloadQueue pendingQueue *pendingQueue controlQueue *controlQueue @@ -329,7 +329,7 @@ func createAssociation(config Config) *Association { myMaxNumOutboundStreams: math.MaxUint16, myMaxNumInboundStreams: math.MaxUint16, - payloadQueue: newPayloadQueue(), + receivedChunkTracker: newReceivedChunkTracker(), inflightQueue: newPayloadQueue(), pendingQueue: newPendingQueue(), controlQueue: newControlQueue(), @@ -1406,7 +1406,7 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet { a.name, d.tsn, d.immediateSack, len(d.userData)) a.stats.incDATAs() - canPush := a.payloadQueue.canPush(d, a.peerLastTSN, a.getMaxTSNOffset()) + canPush := a.receivedChunkTracker.canPush(d, a.peerLastTSN, a.getMaxTSNOffset()) if canPush { s := a.getOrCreateStream(d.streamIdentifier, true, PayloadTypeUnknown) if s == nil { @@ -1418,14 +1418,14 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet { if a.getMyReceiverWindowCredit() > 0 { // Pass the new chunk to stream level as soon as it arrives - a.payloadQueue.push(d, a.peerLastTSN) + a.receivedChunkTracker.push(d.tsn, a.peerLastTSN) s.handleData(d) } else { // Receive buffer is full - lastTSN, ok := a.payloadQueue.getLastTSNReceived() + lastTSN, ok := a.receivedChunkTracker.getLastTSNReceived() if ok && sna32LT(d.tsn, lastTSN) { a.log.Debugf("[%s] receive buffer full, but accepted as this is a missing chunk with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber) - a.payloadQueue.push(d, a.peerLastTSN) + a.receivedChunkTracker.push(d.tsn, a.peerLastTSN) s.handleData(d) } else { a.log.Debugf("[%s] receive buffer full. dropping DATA with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber) @@ -1449,7 +1449,7 @@ func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool) // Meaning, if peerLastTSN+1 points to a chunk that is received, // advance peerLastTSN until peerLastTSN+1 points to unreceived chunk. for { - if _, popOk := a.payloadQueue.pop(a.peerLastTSN + 1); !popOk { + if popOk := a.receivedChunkTracker.pop(a.peerLastTSN + 1); !popOk { break } a.peerLastTSN++ @@ -1463,9 +1463,9 @@ func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool) } } - hasPacketLoss := (a.payloadQueue.size() > 0) + hasPacketLoss := (a.receivedChunkTracker.size() > 0) if hasPacketLoss { - a.log.Tracef("[%s] packetloss: %s", a.name, a.payloadQueue.getGapAckBlocksString(a.peerLastTSN)) + a.log.Tracef("[%s] packetloss: %s", a.name, a.receivedChunkTracker.getGapAckBlocksString(a.peerLastTSN)) } if (a.ackState != ackStateImmediate && !sackImmediately && !hasPacketLoss && a.ackMode == ackModeNormal) || a.ackMode == ackModeAlwaysDelay { @@ -2084,7 +2084,7 @@ func (a *Association) handleForwardTSN(c *chunkForwardTSN) []*packet { // Advance peerLastTSN for sna32LT(a.peerLastTSN, c.newCumulativeTSN) { - a.payloadQueue.pop(a.peerLastTSN + 1) // may not exist + a.receivedChunkTracker.pop(a.peerLastTSN + 1) // may not exist a.peerLastTSN++ } @@ -2427,8 +2427,8 @@ func (a *Association) createSelectiveAckChunk() *chunkSelectiveAck { sack := &chunkSelectiveAck{} sack.cumulativeTSNAck = a.peerLastTSN sack.advertisedReceiverWindowCredit = a.getMyReceiverWindowCredit() - sack.duplicateTSN = a.payloadQueue.popDuplicates() - sack.gapAckBlocks = a.payloadQueue.getGapAckBlocks(a.peerLastTSN) + sack.duplicateTSN = a.receivedChunkTracker.popDuplicates() + sack.gapAckBlocks = a.receivedChunkTracker.getGapAckBlocks(a.peerLastTSN) return sack } diff --git a/association_test.go b/association_test.go index bfb8d76e..dea1c20c 100644 --- a/association_test.go +++ b/association_test.go @@ -1311,14 +1311,7 @@ func TestHandleForwardTSN(t *testing.T) { prevTSN := a.peerLastTSN // this chunk is blocked by the missing chunk at tsn=1 - a.payloadQueue.push(&chunkPayloadData{ - beginningFragment: true, - endingFragment: true, - tsn: a.peerLastTSN + 2, - streamIdentifier: 0, - streamSequenceNumber: 1, - userData: []byte("ABC"), - }, a.peerLastTSN) + a.receivedChunkTracker.push(a.peerLastTSN+2, a.peerLastTSN) fwdtsn := &chunkForwardTSN{ newCumulativeTSN: a.peerLastTSN + 1, @@ -1348,14 +1341,7 @@ func TestHandleForwardTSN(t *testing.T) { prevTSN := a.peerLastTSN // this chunk is blocked by the missing chunk at tsn=1 - a.payloadQueue.push(&chunkPayloadData{ - beginningFragment: true, - endingFragment: true, - tsn: a.peerLastTSN + 3, - streamIdentifier: 0, - streamSequenceNumber: 1, - userData: []byte("ABC"), - }, a.peerLastTSN) + a.receivedChunkTracker.push(a.peerLastTSN+3, a.peerLastTSN) fwdtsn := &chunkForwardTSN{ newCumulativeTSN: a.peerLastTSN + 1, diff --git a/payload_queue.go b/payload_queue.go index a0b1b26f..3c7cdb28 100644 --- a/payload_queue.go +++ b/payload_queue.go @@ -36,14 +36,6 @@ func (q *payloadQueue) updateSortedKeys() { }) } -func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32, maxTSNOffset uint32) bool { - _, ok := q.chunkMap[p.tsn] - if ok || sna32LTE(p.tsn, cumulativeTSN) || sna32GTE(p.tsn, cumulativeTSN+maxTSNOffset) { - return false - } - return true -} - func (q *payloadQueue) pushNoCheck(p *chunkPayloadData) { q.chunkMap[p.tsn] = p q.nBytes += len(p.userData) diff --git a/received_chunk_tracker.go b/received_chunk_tracker.go new file mode 100644 index 00000000..ff5db270 --- /dev/null +++ b/received_chunk_tracker.go @@ -0,0 +1,158 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "fmt" + "strings" +) + +// receivedChunkTracker tracks received chunks for maintaining ACK ranges +type receivedChunkTracker struct { + tsns map[uint32]struct{} + dupTSN []uint32 + ranges []ackRange +} + +// ackRange is a contiguous range of chunks that we have received +type ackRange struct { + start uint32 + end uint32 +} + +func newReceivedChunkTracker() *receivedChunkTracker { + return &receivedChunkTracker{tsns: make(map[uint32]struct{})} +} + +func (q *receivedChunkTracker) canPush(p *chunkPayloadData, cumulativeTSN uint32, maxTSNOffset uint32) bool { + _, ok := q.tsns[p.tsn] + if ok || sna32LTE(p.tsn, cumulativeTSN) || sna32GTE(p.tsn, cumulativeTSN+maxTSNOffset) { + return false + } + return true +} + +// push pushes a tsn for tracking. If the tsn is already tracked or +// older than our cumulativeTSN marker, it will be recorded as a duplicate, +// which can later be retrieved using popDuplicates. +func (q *receivedChunkTracker) push(tsn uint32, cumulativeTSN uint32) bool { + _, ok := q.tsns[tsn] + if ok || sna32LTE(tsn, cumulativeTSN) { + // Found the packet, log in dups + q.dupTSN = append(q.dupTSN, tsn) + return false + } + q.tsns[tsn] = struct{}{} + + insert := true + var pos int + for pos = len(q.ranges) - 1; pos >= 0; pos-- { + if tsn == q.ranges[pos].end+1 { + q.ranges[pos].end++ + insert = false + break + } + if tsn == q.ranges[pos].start-1 { + q.ranges[pos].start-- + insert = false + break + } + if tsn > q.ranges[pos].end { + break + } + } + if insert { + // pos is at the element just before the insertion point + // increment and make it equal to the insertion point + pos++ + q.ranges = append(q.ranges, ackRange{}) + copy(q.ranges[pos+1:], q.ranges[pos:]) + q.ranges[pos] = ackRange{start: tsn, end: tsn} + } else { + // extended element at pos, check if we can merge it with adjacent elements + if pos-1 >= 0 { + if q.ranges[pos-1].end+1 >= q.ranges[pos].start { + q.ranges[pos-1] = ackRange{ + start: q.ranges[pos-1].start, + end: q.ranges[pos].end, + } + copy(q.ranges[pos:], q.ranges[pos+1:]) + q.ranges = q.ranges[:len(q.ranges)-1] + // We have merged pos and pos-1 in to pos-1, update pos to reflect that. + // Not updating this won't be an error but it's nice to maintain the invariant + pos-- + } + } + if pos+1 < len(q.ranges) { + if q.ranges[pos].end+1 >= q.ranges[pos+1].start { + q.ranges[pos+1] = ackRange{ + start: q.ranges[pos].start, + end: q.ranges[pos+1].end, + } + copy(q.ranges[pos:], q.ranges[pos+1:]) + q.ranges = q.ranges[:len(q.ranges)-1] + } + } + } + return true +} + +// pop pops only if the oldest chunk's TSN matches the given TSN. +func (q *receivedChunkTracker) pop(tsn uint32) bool { + if len(q.ranges) == 0 || q.ranges[0].start != tsn { + return false + } + q.ranges[0].start++ + if q.ranges[0].start > q.ranges[0].end { + q.ranges = q.ranges[1:] + } + delete(q.tsns, tsn) + return true +} + +// popDuplicates returns an array of TSN values that were found duplicate. +func (q *receivedChunkTracker) popDuplicates() []uint32 { + dups := q.dupTSN + q.dupTSN = []uint32{} + return dups +} + +// receivedPacketTracker getGapACKBlocks returns gapAckBlocks after the cummulative TSN +func (q *receivedChunkTracker) getGapAckBlocks(cumulativeTSN uint32) []gapAckBlock { + gapAckBlocks := make([]gapAckBlock, 0, len(q.ranges)) + for _, ar := range q.ranges { + if ar.end > cumulativeTSN { + st := ar.start + if st < cumulativeTSN { + st = cumulativeTSN + 1 + } + gapAckBlocks = append(gapAckBlocks, gapAckBlock{ + start: uint16(st - cumulativeTSN), + end: uint16(ar.end - cumulativeTSN), + }) + } + } + return gapAckBlocks +} + +func (q *receivedChunkTracker) getGapAckBlocksString(cumulativeTSN uint32) string { + gapAckBlocks := q.getGapAckBlocks(cumulativeTSN) + sb := strings.Builder{} + sb.WriteString(fmt.Sprintf("cumTSN=%d", cumulativeTSN)) + for _, b := range gapAckBlocks { + sb.WriteString(fmt.Sprintf(",%d-%d", b.start, b.end)) + } + return sb.String() +} + +func (q *receivedChunkTracker) getLastTSNReceived() (uint32, bool) { + if len(q.ranges) == 0 { + return 0, false + } + return q.ranges[len(q.ranges)-1].end, true +} + +func (q *receivedChunkTracker) size() int { + return len(q.tsns) +} diff --git a/received_chunk_tracker_test.go b/received_chunk_tracker_test.go new file mode 100644 index 00000000..258a37ef --- /dev/null +++ b/received_chunk_tracker_test.go @@ -0,0 +1,125 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "fmt" + "math/rand" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestReceivedPacketTrackerPushPop(t *testing.T) { + q := newReceivedChunkTracker() + for i := uint32(1); i < 100; i++ { + q.push(i, 0) + } + // leave a gap at position 100 + for i := uint32(101); i < 200; i++ { + q.push(i, 0) + } + for i := uint32(2); i < 200; i++ { + require.False(t, q.pop(i)) // all pop will fail till we pop the first tsn + } + for i := uint32(1); i < 100; i++ { + require.True(t, q.pop(i)) + } + // 101 is the smallest value now + for i := uint32(102); i < 200; i++ { + require.False(t, q.pop(i)) + } + q.push(100, 99) + for i := uint32(100); i < 200; i++ { + require.True(t, q.pop(i)) + } + + // q is empty now + require.Equal(t, q.size(), 0) + for i := uint32(0); i < 200; i++ { + require.False(t, q.pop(i)) + } +} + +func TestReceivedPacketTrackerGapACKBlocksStress(t *testing.T) { + testChunks := func(chunks []uint32, st uint32) { + if len(chunks) == 0 { + return + } + expected := make([]gapAckBlock, 0, len(chunks)) + cr := ackRange{start: chunks[0], end: chunks[0]} + for i := 1; i < len(chunks); i++ { + if cr.end+1 != chunks[i] { + expected = append(expected, gapAckBlock{ + start: uint16(cr.start - st), + end: uint16(cr.end - st), + }) + cr = ackRange{start: chunks[i], end: chunks[i]} + } else { + cr.end++ + } + } + expected = append(expected, gapAckBlock{ + start: uint16(cr.start - st), + end: uint16(cr.end - st), + }) + + q := newReceivedChunkTracker() + rand.Shuffle(len(chunks), func(i, j int) { + chunks[i], chunks[j] = chunks[j], chunks[i] + }) + for _, t := range chunks { + q.push(t, 0) + } + res := q.getGapAckBlocks(0) + require.Equal(t, expected, res, chunks) + } + chunks := make([]uint32, 0, 10) + for i := 1; i < (1 << 10); i++ { + for j := 0; j < 10; j++ { + if i&(1< Date: Tue, 2 Apr 2024 15:43:49 +0530 Subject: [PATCH 2/3] add benchmark --- association_bench_test.go | 108 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 association_bench_test.go diff --git a/association_bench_test.go b/association_bench_test.go new file mode 100644 index 00000000..87d938ef --- /dev/null +++ b/association_bench_test.go @@ -0,0 +1,108 @@ +package sctp + +import ( + "io" + "net" + "testing" + + "github.com/pion/logging" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type netConnWrapper struct { + net.PacketConn + remoteAddr net.Addr +} + +func (c *netConnWrapper) Read(b []byte) (int, error) { + n, _, err := c.PacketConn.ReadFrom(b) + return n, err +} + +func (c *netConnWrapper) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *netConnWrapper) Write(b []byte) (n int, err error) { + return c.PacketConn.WriteTo(b, c.remoteAddr) +} + +var _ net.Conn = &netConnWrapper{} + +func newNetConnPair(p1 net.PacketConn, p2 net.PacketConn) (net.Conn, net.Conn) { + return &netConnWrapper{ + PacketConn: p1, + remoteAddr: p2.LocalAddr(), + }, + &netConnWrapper{ + PacketConn: p2, + remoteAddr: p1.LocalAddr(), + } +} + +func BenchmarkSCTPThroughput(b *testing.B) { + b.ReportAllocs() + p1, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(b, err) + p2, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(b, err) + + c1, c2 := newNetConnPair(p1, p2) + var server *Association + done := make(chan bool) + go func() { + var err error + server, err = Server(Config{ + Name: "server", + NetConn: c1, + MaxReceiveBufferSize: 1024 * 1024, + LoggerFactory: &logging.DefaultLoggerFactory{}, + }) + require.NoError(b, err) + done <- true + }() + + var client *Association + go func() { + var err error + client, err = Client(Config{ + Name: "client", + NetConn: c2, + MaxReceiveBufferSize: 1024 * 1024, + LoggerFactory: &logging.DefaultLoggerFactory{}, + }) + require.NoError(b, err) + done <- true + }() + <-done + <-done + serverBuf := make([]byte, 16*(1<<10)) + clientBuf := make([]byte, 16*(1<<10)) + for i := 0; i < b.N; i++ { + s, err := client.OpenStream(uint16(i), PayloadTypeWebRTCBinary) + require.NoError(b, err) + go func() { + s, err := server.AcceptStream() + assert.NoError(b, err) + for { + _, err := s.Read(serverBuf) + if err != nil { + if err == io.EOF { + s.Close() + break + } else { + b.Error("invalid err", err) + } + } + } + done <- true + }() + for i := 0; i < 1000; i++ { + _, err := s.Write(clientBuf) + require.NoError(b, err) + } + s.Close() + <-done + } +} From c57fa569552f90cf2962eacc4c04bce9e32bf190 Mon Sep 17 00:00:00 2001 From: sukun Date: Wed, 3 Apr 2024 17:15:09 +0530 Subject: [PATCH 3/3] fix >= check for SNA --- received_chunk_tracker.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/received_chunk_tracker.go b/received_chunk_tracker.go index ff5db270..e56a2831 100644 --- a/received_chunk_tracker.go +++ b/received_chunk_tracker.go @@ -10,7 +10,7 @@ import ( // receivedChunkTracker tracks received chunks for maintaining ACK ranges type receivedChunkTracker struct { - tsns map[uint32]struct{} + tsns map[uint32]struct{} dupTSN []uint32 ranges []ackRange } @@ -72,7 +72,7 @@ func (q *receivedChunkTracker) push(tsn uint32, cumulativeTSN uint32) bool { } else { // extended element at pos, check if we can merge it with adjacent elements if pos-1 >= 0 { - if q.ranges[pos-1].end+1 >= q.ranges[pos].start { + if q.ranges[pos-1].end+1 == q.ranges[pos].start { q.ranges[pos-1] = ackRange{ start: q.ranges[pos-1].start, end: q.ranges[pos].end, @@ -85,7 +85,7 @@ func (q *receivedChunkTracker) push(tsn uint32, cumulativeTSN uint32) bool { } } if pos+1 < len(q.ranges) { - if q.ranges[pos].end+1 >= q.ranges[pos+1].start { + if q.ranges[pos].end+1 == q.ranges[pos+1].start { q.ranges[pos+1] = ackRange{ start: q.ranges[pos].start, end: q.ranges[pos+1].end,