From ad7ca4478fbe0cbc31a9a2c2075afa30ce912207 Mon Sep 17 00:00:00 2001 From: OrlandoCo Date: Thu, 7 Jan 2021 12:49:04 -0600 Subject: [PATCH] Fix(SFU): Prevent multiple receiver close calls on simulcast --- pkg/sfu/receiver.go | 5 +++-- pkg/sfu/router.go | 31 ++++++++++++++++--------------- pkg/sfu/subscriber.go | 12 +++++++----- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index 9bd5d9b22..65e69e505 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -34,7 +34,8 @@ type Receiver interface { // WebRTCReceiver receives a video track type WebRTCReceiver struct { sync.Mutex - rtcpMu sync.RWMutex + rtcpMu sync.RWMutex + closeOnce sync.Once peerID string trackID string @@ -215,7 +216,7 @@ func (w *WebRTCReceiver) writeRTP(layer int) { w.closeTracks(layer) w.nackWorker.Stop() if w.onCloseHandler != nil { - w.onCloseHandler() + w.closeOnce.Do(w.onCloseHandler) } }() for pkt := range w.buffers[layer].PacketChan() { diff --git a/pkg/sfu/router.go b/pkg/sfu/router.go index 0c9d99c85..bb62313ac 100644 --- a/pkg/sfu/router.go +++ b/pkg/sfu/router.go @@ -128,12 +128,19 @@ func (r *router) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.TrackRe } }) - recv := r.receivers[trackID] - if recv == nil { + recv, ok := r.receivers[trackID] + if !ok { recv = NewWebRTCReceiver(receiver, track, r.id) r.receivers[trackID] = recv recv.SetRTCPCh(r.rtcpCh) recv.OnCloseHandler(func() { + if r.config.WithStats { + if track.Kind() == webrtc.RTPCodecTypeVideo { + stats.VideoTracks.Dec() + } else { + stats.AudioTracks.Dec() + } + } r.deleteReceiver(trackID, uint32(track.SSRC())) }) publish = true @@ -217,11 +224,13 @@ func (r *router) addDownTrack(sub *Subscriber, recv Receiver) error { // nolint:scopelint downTrack.OnCloseHandler(func() { - if err := sub.pc.RemoveTrack(downTrack.transceiver.Sender()); err != nil { - log.Errorf("Error closing down track: %v", err) - } else { - sub.RemoveDownTrack(recv.StreamID(), downTrack) - sub.negotiate() + if sub.pc.ConnectionState() != webrtc.PeerConnectionStateClosed { + if err := sub.pc.RemoveTrack(downTrack.transceiver.Sender()); err != nil { + log.Errorf("Error closing down track: %v", err) + } else { + sub.RemoveDownTrack(recv.StreamID(), downTrack) + sub.negotiate() + } } }) @@ -236,14 +245,6 @@ func (r *router) addDownTrack(sub *Subscriber, recv Receiver) error { func (r *router) deleteReceiver(track string, ssrc uint32) { r.Lock() - if r.config.WithStats { - if r.receivers[track].Kind() == webrtc.RTPCodecTypeVideo { - stats.VideoTracks.Dec() - } else { - stats.AudioTracks.Dec() - } - } - delete(r.receivers, track) delete(r.stats, ssrc) r.Unlock() diff --git a/pkg/sfu/subscriber.go b/pkg/sfu/subscriber.go index 34dffd4b7..be0c349bc 100644 --- a/pkg/sfu/subscriber.go +++ b/pkg/sfu/subscriber.go @@ -135,10 +135,12 @@ func (s *Subscriber) RemoveDownTrack(streamID string, downTrack *DownTrack) { idx = i } } - dts[idx] = dts[len(dts)-1] - dts[len(dts)-1] = nil - dts = dts[:len(dts)-1] - s.tracks[streamID] = dts + if idx >= 0 { + dts[idx] = dts[len(dts)-1] + dts[len(dts)-1] = nil + dts = dts[:len(dts)-1] + s.tracks[streamID] = dts + } } } @@ -193,7 +195,7 @@ func (s *Subscriber) downTracksReports() { for { time.Sleep(5 * time.Second) - if s.pc.ConnectionState() == webrtc.ICETransportStateClosed { + if s.pc.ConnectionState() == webrtc.PeerConnectionStateClosed { return }