diff --git a/peerconnection.go b/peerconnection.go index 1f0a7d1591b..1ac253bb1a3 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -12,7 +12,6 @@ import ( "crypto/rand" "errors" "fmt" - "io" "strconv" "strings" "sync" @@ -1480,7 +1479,7 @@ func (pc *PeerConnection) startSCTP() { } } -func (pc *PeerConnection) handleUndeclaredSSRC(ssrc SSRC, remoteDescription *SessionDescription) (handled bool, err error) { +func (pc *PeerConnection) handleUndeclaredSSRC(ssrc SSRC, payloadType PayloadType, remoteDescription *SessionDescription) (handled bool, err error) { if len(remoteDescription.parsed.MediaDescriptions) != 1 { return false, nil } @@ -1534,7 +1533,7 @@ func (pc *PeerConnection) handleUndeclaredSSRC(ssrc SSRC, remoteDescription *Ses return true, nil } -func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) error { //nolint:gocognit +func (pc *PeerConnection) handleIncomingSSRC(rtpStream *srtp.ReadStreamSRTP, ssrc SSRC, payloadType PayloadType) error { //nolint:gocognit remoteDescription := pc.RemoteDescription() if remoteDescription == nil { return errPeerConnRemoteDescriptionNil @@ -1553,7 +1552,7 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err } // If the remote SDP was only one media section the ssrc doesn't have to be explicitly declared - if handled, err := pc.handleUndeclaredSSRC(ssrc, remoteDescription); handled || err != nil { + if handled, err := pc.handleUndeclaredSSRC(ssrc, payloadType, remoteDescription); handled || err != nil { return err } @@ -1569,18 +1568,6 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err repairStreamIDExtensionID, _, _ := pc.api.mediaEngine.getHeaderExtensionID(RTPHeaderExtensionCapability{sdesRepairRTPStreamIDURI}) - b := make([]byte, pc.api.settingEngine.getReceiveMTU()) - - i, err := rtpStream.Read(b) - if err != nil { - return err - } - - if i < 4 { - return errRTPTooShort - } - - payloadType := PayloadType(b[1] & 0x7f) params, err := pc.api.mediaEngine.getRTPParametersByPayloadType(payloadType) if err != nil { return err @@ -1592,8 +1579,10 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err return err } + b := make([]byte, pc.api.settingEngine.getReceiveMTU()) var mid, rid, rsid string var paddingOnly bool + firstPacket := true for readCount := 0; readCount <= simulcastProbeCount; readCount++ { if mid == "" || (rid == "" && rsid == "") { // skip padding only packets for probing @@ -1601,7 +1590,13 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err readCount-- } - i, _, err := interceptor.Read(b, nil) + if !firstPacket { + // Consume the packet that we peeked last time + if _, err := readStream.Read([]byte{}); err != nil { + return err + } + } + i, err := readStream.Peek(b) if err != nil { return err } @@ -1610,6 +1605,7 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err return err } + firstPacket = false continue } @@ -1653,7 +1649,7 @@ func (pc *PeerConnection) undeclaredRTPMediaProcessor() { return } - stream, ssrc, err := srtpSession.AcceptStream() + stream, ssrc, payloadType, err := srtpSession.AcceptStreamWithPayloadType() if err != nil { pc.log.Warnf("Failed to accept RTP %v", err) return @@ -1674,12 +1670,12 @@ func (pc *PeerConnection) undeclaredRTPMediaProcessor() { continue } - go func(rtpStream io.Reader, ssrc SSRC) { - if err := pc.handleIncomingSSRC(rtpStream, ssrc); err != nil { + go func(rtpStream *srtp.ReadStreamSRTP, ssrc SSRC, payloadType PayloadType) { + if err := pc.handleIncomingSSRC(rtpStream, ssrc, payloadType); err != nil { pc.log.Errorf(incomingUnhandledRTPSsrc, ssrc, err) } atomic.AddUint64(&simulcastRoutineCount, ^uint64(0)) - }(stream, SSRC(ssrc)) + }(stream, SSRC(ssrc), PayloadType(payloadType)) } }