Skip to content

Commit

Permalink
Add BindingRequestHandler
Browse files Browse the repository at this point in the history
Allow the user to perform custom processing for inbound STUN Binding
requests. This allows users to do some of the following

* Log incoming Binding Requests for debugging
* Implement draft-thatcher-ice-renomination
* Implement custom CandidatePair switching logic

Resolves pion/webrtc#2539
Resolves pion/webrtc#2585
Resolves #623
  • Loading branch information
Sean-Der committed May 2, 2024
1 parent 01f82e2 commit da3175f
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 5 deletions.
6 changes: 6 additions & 0 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ type Agent struct {
taskLoopDone chan struct{}
err atomicx.Error

// Callback that allows user to implement custom behavior
// for STUN Binding Requests
userBindingRequestHandler func(m *stun.Message, local, remote Candidate, pair *CandidatePair) bool

gatherCandidateCancel func()
gatherCandidateDone chan struct{}

Expand Down Expand Up @@ -322,6 +326,8 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
includeLoopback: config.IncludeLoopback,

disableActiveTCP: config.DisableActiveTCP,

userBindingRequestHandler: config.BindingRequestHandler,
}

if a.net == nil {
Expand Down
7 changes: 7 additions & 0 deletions agent_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,13 @@ type AgentConfig struct {
// DisableActiveTCP can be used to disable Active TCP candidates. Otherwise when TCP is enabled
// Active TCP candidates will be created when a new passive TCP remote candidate is added.
DisableActiveTCP bool

// BindingRequestHandler allows applications to perform logic on incoming STUN Binding Requests
// This was implemented to allow users to
// * Log incoming Binding Requests for debugging
// * Implement draft-thatcher-ice-renomination
// * Implement custom CandidatePair switching logic
BindingRequestHandler func(m *stun.Message, local, remote Candidate, pair *CandidatePair) bool
}

// initWithDefaults populates an agent and falls back to defaults if fields are unset
Expand Down
20 changes: 15 additions & 5 deletions selection.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ func (s *controllingSelector) HandleBindingRequest(m *stun.Message, local, remot
s.nominatePair(p)
}
}

if s.agent.userBindingRequestHandler != nil {
if shouldSwitch := s.agent.userBindingRequestHandler(m, local, remote, p); shouldSwitch {
s.agent.setSelectedPair(p)
}
}
}

func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) {
Expand Down Expand Up @@ -242,23 +248,21 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot
}

func (s *controlledSelector) HandleBindingRequest(m *stun.Message, local, remote Candidate) {
useCandidate := m.Contains(stun.AttrUseCandidate)

p := s.agent.findPair(local, remote)
if p == nil {
p = s.agent.addPair(local, remote)
}

if useCandidate {
if m.Contains(stun.AttrUseCandidate) {
// https://tools.ietf.org/html/rfc8445#section-7.3.1.5

if p.state == CandidatePairStateSucceeded {
// If the state of this pair is Succeeded, it means that the check
// previously sent by this pair produced a successful response and
// generated a valid pair (Section 7.2.5.3.2). The agent sets the
// nominated flag value of the valid pair to true.
if selectedPair := s.agent.getSelectedPair(); selectedPair == nil ||
(selectedPair != p && selectedPair.priority() <= p.priority()) {
selectedPair := s.agent.getSelectedPair()
if selectedPair == nil || (selectedPair != p && selectedPair.priority() <= p.priority()) {
s.agent.setSelectedPair(p)
} else if selectedPair != p {
s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", p, selectedPair)
Expand All @@ -278,6 +282,12 @@ func (s *controlledSelector) HandleBindingRequest(m *stun.Message, local, remote

s.agent.sendBindingSuccess(m, local, remote)
s.PingCandidate(local, remote)

if s.agent.userBindingRequestHandler != nil {
if shouldSwitch := s.agent.userBindingRequestHandler(m, local, remote, p); shouldSwitch {
s.agent.setSelectedPair(p)
}
}
}

type liteSelector struct {
Expand Down
156 changes: 156 additions & 0 deletions selection_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

//go:build !js
// +build !js

package ice

import (
"bytes"
"context"
"errors"
"io"
"net"
"sync/atomic"
"testing"
"time"

"github.com/pion/stun"
"github.com/pion/transport/v3/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func sendUntilDone(t *testing.T, writingConn, readingConn net.Conn, maxAttempts int) bool {
testMessage := []byte("Hello World")
testBuffer := make([]byte, len(testMessage))

readDone, readDoneCancel := context.WithCancel(context.Background())
go func() {
_, err := readingConn.Read(testBuffer)
if errors.Is(err, io.EOF) {
return
}

require.NoError(t, err)
require.True(t, bytes.Equal(testMessage, testBuffer))

readDoneCancel()
}()

attempts := 0
for {
select {
case <-time.After(5 * time.Millisecond):
if attempts > maxAttempts {
return false
}

_, err := writingConn.Write(testMessage)
require.NoError(t, err)
attempts++
case <-readDone.Done():
return true
}
}
}

func TestBindingRequestHandler(t *testing.T) {
defer test.CheckRoutines(t)()
defer test.TimeOut(time.Second * 30).Stop()

var switchToNewCandidatePair, controlledLoggingFired atomic.Value
oneHour := time.Hour
keepaliveInterval := time.Millisecond * 20

aNotifier, aConnected := onConnected()
bNotifier, bConnected := onConnected()
controllingAgent, err := NewAgent(&AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
MulticastDNSMode: MulticastDNSModeDisabled,
KeepaliveInterval: &keepaliveInterval,
CheckInterval: &oneHour,
BindingRequestHandler: func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool {
controlledLoggingFired.Store(true)
return false
},
})
require.NoError(t, err)
require.NoError(t, controllingAgent.OnConnectionStateChange(aNotifier))

controlledAgent, err := NewAgent(&AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4},
MulticastDNSMode: MulticastDNSModeDisabled,
KeepaliveInterval: &keepaliveInterval,
CheckInterval: &oneHour,
BindingRequestHandler: func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool {
// Don't switch candidate pair until we are ready
val, ok := switchToNewCandidatePair.Load().(bool)
return ok && val
},
})
require.NoError(t, err)
require.NoError(t, controlledAgent.OnConnectionStateChange(bNotifier))

controlledConn, controllingConn := connect(controlledAgent, controllingAgent)
<-aConnected
<-bConnected

// Assert we have connected and can send data
require.True(t, sendUntilDone(t, controlledConn, controllingConn, 100))

// Take the lock on the controlling Agent and unset state
assert.NoError(t, controlledAgent.run(controlledAgent.context(), func(_ context.Context, controlledAgent *Agent) {
for net, cs := range controlledAgent.remoteCandidates {
for _, c := range cs {
require.NoError(t, c.close())
}
delete(controlledAgent.remoteCandidates, net)
}

for _, c := range controlledAgent.localCandidates[NetworkTypeUDP4] {
cast, ok := c.(*CandidateHost)
require.True(t, ok)
cast.remoteCandidateCaches = map[AddrPort]Candidate{}
}

controlledAgent.setSelectedPair(nil)
controlledAgent.checklist = make([]*CandidatePair, 0)
}))

// Assert that Selected Candidate pair has only been unset on Controlled side
candidatePair, err := controlledAgent.GetSelectedCandidatePair()
assert.Nil(t, candidatePair)
assert.NoError(t, err)

candidatePair, err = controllingAgent.GetSelectedCandidatePair()
assert.NotNil(t, candidatePair)
assert.NoError(t, err)

// Sending will fail, we no longer have a selected candidate pair
require.False(t, sendUntilDone(t, controlledConn, controllingConn, 20))

// Send STUN Binding requests until a new Selected Candidate Pair has been set by BindingRequestHandler
switchToNewCandidatePair.Store(true)
for {
controllingAgent.requestConnectivityCheck()

candidatePair, err = controlledAgent.GetSelectedCandidatePair()
require.NoError(t, err)
if candidatePair != nil {
break
}

time.Sleep(time.Millisecond * 5)
}

// We have a new selected candidate pair because of BindingRequestHandler, test that it works
require.True(t, sendUntilDone(t, controllingConn, controlledConn, 100))

fired, ok := controlledLoggingFired.Load().(bool)
require.True(t, ok)
require.True(t, fired)

closePipe(t, controllingConn, controlledConn)
}

0 comments on commit da3175f

Please sign in to comment.