Skip to content

Commit

Permalink
Implement DTLS/SRTP/SCTP restart
Browse files Browse the repository at this point in the history
Fixes #1636
  • Loading branch information
Antonito committed Jun 29, 2021
1 parent 7948437 commit 13df7ae
Show file tree
Hide file tree
Showing 5 changed files with 440 additions and 8 deletions.
20 changes: 16 additions & 4 deletions datachannel.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (api *API) NewDataChannel(transport *SCTPTransport, params *DataChannelPara
return nil, err
}

err = d.open(transport)
err = d.open(transport, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -103,14 +103,14 @@ func (api *API) newDataChannel(params *DataChannelParameters, log logging.Levele
}

// open opens the datachannel over the sctp transport
func (d *DataChannel) open(sctpTransport *SCTPTransport) error {
func (d *DataChannel) open(sctpTransport *SCTPTransport, restart bool) error {
association := sctpTransport.association()
if association == nil {
return errSCTPNotEstablished
}

d.mu.Lock()
if d.sctpTransport != nil { // already open
if d.sctpTransport != nil && !restart { // already open & not restarting
d.mu.Unlock()
return nil
}
Expand Down Expand Up @@ -164,6 +164,11 @@ func (d *DataChannel) open(sctpTransport *SCTPTransport) error {
return err
}

// If restarting, the `Open` event should be triggered again, once.
if restart {
d.openHandlerOnce = sync.Once{}
}

// bufferedAmountLowThreshold and onBufferedAmountLow might be set earlier
dc.SetBufferedAmountLowThreshold(d.bufferedAmountLowThreshold)
dc.OnBufferedAmountLow(d.onBufferedAmountLow)
Expand Down Expand Up @@ -309,11 +314,18 @@ func (d *DataChannel) readLoop() {
n, isString, err := d.dataChannel.ReadDataChannel(buffer)
if err != nil {
rlBufPool.Put(buffer) // nolint:staticcheck

previousState := d.ReadyState()
d.setReadyState(DataChannelStateClosed)

if err != io.EOF {
d.onError(err)
}
d.onClose()

// https://www.w3.org/TR/webrtc/#announcing-a-data-channel-as-closed
if previousState != DataChannelStateClosed {
d.onClose()
}
return
}

Expand Down
25 changes: 25 additions & 0 deletions dtlstransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,31 @@ func (t *DTLSTransport) startSRTP() error {
return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
}

isAlreadyRunning := func() bool {
select {
case <-t.srtpReady:
return true
default:
return false
}
}()

if isAlreadyRunning {
if sess, ok := t.srtpSession.Load().(*srtp.SessionSRTP); ok {
if updateErr := sess.UpdateContext(srtpConfig); updateErr != nil {
return updateErr
}
}

if sess, ok := t.srtcpSession.Load().(*srtp.SessionSRTCP); ok {
if updateErr := sess.UpdateContext(srtpConfig); updateErr != nil {
return updateErr
}
}

return nil
}

srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig)
if err != nil {
return fmt.Errorf("%w: %v", errFailedToStartSRTP, err)
Expand Down
56 changes: 54 additions & 2 deletions peerconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,59 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error {
pc.ops.Enqueue(func() {
pc.startRTP(true, &desc, currentTransceivers)
})
} else if pc.dtlsTransport.State() != DTLSTransportStateNew {
fingerprint, fingerprintHash, fErr := extractFingerprint(desc.parsed)
if fErr != nil {
return fErr
}

fingerPrintDidChange := true

for _, fp := range pc.dtlsTransport.remoteParameters.Fingerprints {
if fingerprint == fp.Value && fingerprintHash == fp.Algorithm {
fingerPrintDidChange = false
break
}
}

if fingerPrintDidChange {
pc.ops.Enqueue(func() {
// SCTP uses DTLS, so prevent any use, by locking, while
// DTLS is restarting.
pc.sctpTransport.lock.Lock()
defer pc.sctpTransport.lock.Unlock()

if dErr := pc.dtlsTransport.Stop(); dErr != nil {
pc.log.Warnf("Failed to stop DTLS: %s", dErr)
}

// libwebrtc switches the connection back to `new`.
pc.dtlsTransport.lock.Lock()
pc.dtlsTransport.onStateChange(DTLSTransportStateNew)
pc.dtlsTransport.lock.Unlock()

// Restart the dtls transport with updated fingerprints
err = pc.dtlsTransport.Start(DTLSParameters{
Role: dtlsRoleFromRemoteSDP(desc.parsed),
Fingerprints: []DTLSFingerprint{{Algorithm: fingerprintHash, Value: fingerprint}},
})
pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State())
if err != nil {
pc.log.Warnf("Failed to restart DTLS: %s", err)
return
}

// If SCTP was enabled, restart it with the new DTLS transport.
if pc.sctpTransport.isStarted {
if dErr := pc.sctpTransport.restart(pc.dtlsTransport.conn); dErr != nil {
pc.log.Warnf("Failed to restart SCTP: %s", dErr)
return
}
}
})
}
}

return nil
}

Expand Down Expand Up @@ -1317,7 +1369,7 @@ func (pc *PeerConnection) startSCTP() {
var openedDCCount uint32
for _, d := range dataChannels {
if d.ReadyState() == DataChannelStateConnecting {
err := d.open(pc.sctpTransport)
err := d.open(pc.sctpTransport, false)
if err != nil {
pc.log.Warnf("failed to open data channel: %s", err)
continue
Expand Down Expand Up @@ -1775,7 +1827,7 @@ func (pc *PeerConnection) CreateDataChannel(label string, options *DataChannelIn

// If SCTP already connected open all the channels
if pc.sctpTransport.State() == SCTPTransportStateConnected {
if err = d.open(pc.sctpTransport); err != nil {
if err = d.open(pc.sctpTransport, false); err != nil {
return nil, err
}
}
Expand Down
Loading

0 comments on commit 13df7ae

Please sign in to comment.