1262 lines
32 KiB
Go
Raw Normal View History

2022-03-10 10:44:48 +01:00
// Package ice implements the Interactive Connectivity Establishment (ICE)
// protocol defined in rfc5245.
package ice
import (
"context"
"net"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pion/logging"
"github.com/pion/mdns"
"github.com/pion/stun"
"github.com/pion/transport/packetio"
"github.com/pion/transport/vnet"
2022-03-10 10:44:48 +01:00
"golang.org/x/net/proxy"
)
// bindingRequest records one outbound STUN Binding request so that a later
// success response can be matched back to it by transaction ID.
type bindingRequest struct {
// timestamp is when the request was recorded; used to expire old entries.
timestamp time.Time
// transactionID is the STUN transaction ID carried by the request.
transactionID [stun.TransactionIDSize]byte
// destination is the address of the remote candidate the request was sent to.
destination net.Addr
// isUseCandidate reports whether the request contained the USE-CANDIDATE attribute.
isUseCandidate bool
}
// Agent represents the ICE agent
type Agent struct {
chanTask chan task
afterRunFn []func(ctx context.Context)
muAfterRun sync.Mutex
onConnectionStateChangeHdlr atomic.Value // func(ConnectionState)
onSelectedCandidatePairChangeHdlr atomic.Value // func(Candidate, Candidate)
onCandidateHdlr atomic.Value // func(Candidate)
// State owned by the taskLoop
onConnected chan struct{}
onConnectedOnce sync.Once
// force candidate to be contacted immediately (instead of waiting for task ticker)
2022-03-10 10:44:48 +01:00
forceCandidateContact chan bool
tieBreaker uint64
lite bool
connectionState ConnectionState
gatheringState GatheringState
mDNSMode MulticastDNSMode
mDNSName string
mDNSConn *mdns.Conn
muHaveStarted sync.Mutex
startedCh <-chan struct{}
startedFn func()
isControlling bool
maxBindingRequests uint16
hostAcceptanceMinWait time.Duration
srflxAcceptanceMinWait time.Duration
prflxAcceptanceMinWait time.Duration
relayAcceptanceMinWait time.Duration
portmin uint16
portmax uint16
2022-03-10 10:44:48 +01:00
candidateTypes []CandidateType
// How long connectivity checks can fail before the ICE Agent
// goes to disconnected
disconnectedTimeout time.Duration
// How long connectivity checks can fail before the ICE Agent
// goes to failed
failedTimeout time.Duration
// How often should we send keepalive packets?
// 0 means never
keepaliveInterval time.Duration
// How often should we run our internal taskLoop to check for state changes when connecting
checkInterval time.Duration
localUfrag string
localPwd string
localCandidates map[NetworkType][]Candidate
remoteUfrag string
remotePwd string
remoteCandidates map[NetworkType][]Candidate
checklist []*CandidatePair
selector pairCandidateSelector
selectedPair atomic.Value // *CandidatePair
urls []*URL
2022-03-10 10:44:48 +01:00
networkTypes []NetworkType
buffer *packetio.Buffer
2022-03-10 10:44:48 +01:00
// LRU of outbound Binding request Transaction IDs
pendingBindingRequests []bindingRequest
// 1:1 D-NAT IP address mapping
extIPMapper *externalIPMapper
// State for closing
done chan struct{}
taskLoopDone chan struct{}
err atomicError
2022-03-10 10:44:48 +01:00
gatherCandidateCancel func()
chanCandidate chan Candidate
chanCandidatePair chan *CandidatePair
chanState chan ConnectionState
loggerFactory logging.LoggerFactory
log logging.LeveledLogger
net *vnet.Net
tcpMux TCPMux
udpMux UDPMux
2022-03-10 10:44:48 +01:00
interfaceFilter func(string) bool
insecureSkipVerify bool
proxyDialer proxy.Dialer
}
// task is a unit of work handed to taskLoop through Agent.chanTask.
type task struct {
// fn is the work to execute; it is called with the agent context and the agent.
fn func(context.Context, *Agent)
// done is closed by taskLoop once fn has returned.
done chan struct{}
}
// afterRun queues f so that taskLoop invokes it once the current task has
// finished. The queue is protected by muAfterRun.
func (a *Agent) afterRun(f func(context.Context)) {
	a.muAfterRun.Lock()
	defer a.muAfterRun.Unlock()

	a.afterRunFn = append(a.afterRunFn, f)
}
// getAfterRunFn hands the queued afterRun callbacks to the caller and empties
// the queue, all while holding muAfterRun.
func (a *Agent) getAfterRunFn() []func(context.Context) {
	a.muAfterRun.Lock()
	pending := a.afterRunFn
	a.afterRunFn = nil
	a.muAfterRun.Unlock()

	return pending
}
// ok returns nil while the agent is open. Once the done channel has been
// closed it returns the error reported by getErr.
func (a *Agent) ok() error {
	select {
	case <-a.done:
		return a.getErr()
	default:
		return nil
	}
}
// getErr returns the error stored on the agent, or ErrClosed when none has
// been stored.
func (a *Agent) getErr() error {
	err := a.err.Load()
	if err == nil {
		return ErrClosed
	}
	return err
}
// run executes t serially on the taskLoop goroutine and blocks until it has
// finished. Blocking tasks must be cancelable by context.
//
// It returns the agent error when the agent is already closed, or ctx.Err()
// when ctx is done before the task could be handed over.
func (a *Agent) run(ctx context.Context, t func(context.Context, *Agent)) error {
	if err := a.ok(); err != nil {
		return err
	}

	finished := make(chan struct{})
	queued := task{fn: t, done: finished}

	select {
	case a.chanTask <- queued:
	case <-ctx.Done():
		return ctx.Err()
	}

	// The task was accepted; wait for taskLoop to signal completion.
	<-finished
	return nil
}
// taskLoop handles registered tasks and agent close.
func (a *Agent) taskLoop() {
after := func() {
for {
// Get and run func registered by afterRun().
fns := a.getAfterRunFn()
if len(fns) == 0 {
break
}
for _, fn := range fns {
fn(a.context())
}
}
}
defer func() {
a.deleteAllCandidates()
a.startedFn()
if err := a.buffer.Close(); err != nil {
a.log.Warnf("failed to close buffer: %v", err)
2022-03-10 10:44:48 +01:00
}
a.closeMulticastConn()
a.updateConnectionState(ConnectionStateClosed)
after()
close(a.chanState)
close(a.chanCandidate)
close(a.chanCandidatePair)
close(a.taskLoopDone)
}()
for {
select {
case <-a.done:
return
case t := <-a.chanTask:
t.fn(a.context(), a)
close(t.done)
after()
}
}
}
// NewAgent creates a new Agent
func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
var err error
if config.PortMax < config.PortMin {
return nil, ErrPort
}
mDNSName := config.MulticastDNSHostName
if mDNSName == "" {
if mDNSName, err = generateMulticastDNSName(); err != nil {
return nil, err
}
}
if !strings.HasSuffix(mDNSName, ".local") || len(strings.Split(mDNSName, ".")) != 2 {
return nil, ErrInvalidMulticastDNSHostName
}
mDNSMode := config.MulticastDNSMode
if mDNSMode == 0 {
mDNSMode = MulticastDNSModeQueryOnly
}
loggerFactory := config.LoggerFactory
if loggerFactory == nil {
loggerFactory = logging.NewDefaultLoggerFactory()
}
log := loggerFactory.NewLogger("ice")
var mDNSConn *mdns.Conn
mDNSConn, mDNSMode, err = createMulticastDNS(mDNSMode, mDNSName, log)
// Opportunistic mDNS: If we can't open the connection, that's ok: we
// can continue without it.
if err != nil {
log.Warnf("Failed to initialize mDNS %s: %v", mDNSName, err)
}
closeMDNSConn := func() {
if mDNSConn != nil {
if mdnsCloseErr := mDNSConn.Close(); mdnsCloseErr != nil {
log.Warnf("Failed to close mDNS: %v", mdnsCloseErr)
}
}
}
2022-03-10 10:44:48 +01:00
startedCtx, startedFn := context.WithCancel(context.Background())
a := &Agent{
chanTask: make(chan task),
chanState: make(chan ConnectionState),
chanCandidate: make(chan Candidate),
chanCandidatePair: make(chan *CandidatePair),
tieBreaker: globalMathRandomGenerator.Uint64(),
lite: config.Lite,
gatheringState: GatheringStateNew,
connectionState: ConnectionStateNew,
localCandidates: make(map[NetworkType][]Candidate),
remoteCandidates: make(map[NetworkType][]Candidate),
urls: config.Urls,
networkTypes: config.NetworkTypes,
onConnected: make(chan struct{}),
buffer: packetio.NewBuffer(),
2022-03-10 10:44:48 +01:00
done: make(chan struct{}),
taskLoopDone: make(chan struct{}),
startedCh: startedCtx.Done(),
startedFn: startedFn,
portmin: config.PortMin,
portmax: config.PortMax,
2022-03-10 10:44:48 +01:00
loggerFactory: loggerFactory,
log: log,
net: config.Net,
proxyDialer: config.ProxyDialer,
mDNSMode: mDNSMode,
mDNSName: mDNSName,
mDNSConn: mDNSConn,
2022-03-10 10:44:48 +01:00
gatherCandidateCancel: func() {},
forceCandidateContact: make(chan bool, 1),
interfaceFilter: config.InterfaceFilter,
insecureSkipVerify: config.InsecureSkipVerify,
}
2022-03-10 10:44:48 +01:00
a.tcpMux = config.TCPMux
if a.tcpMux == nil {
a.tcpMux = newInvalidTCPMux()
2022-03-10 10:44:48 +01:00
}
a.udpMux = config.UDPMux
2022-03-10 10:44:48 +01:00
if a.net == nil {
a.net = vnet.NewNet(nil)
} else if a.net.IsVirtual() {
a.log.Warn("vnet is enabled")
2022-03-10 10:44:48 +01:00
if a.mDNSMode != MulticastDNSModeDisabled {
a.log.Warn("vnet does not support mDNS yet")
2022-03-10 10:44:48 +01:00
}
}
config.initWithDefaults(a)
// Make sure the buffer doesn't grow indefinitely.
// NOTE: We actually won't get anywhere close to this limit.
// SRTP will constantly read from the endpoint and drop packets if it's full.
a.buffer.SetLimitSize(maxBufferSize)
2022-03-10 10:44:48 +01:00
if a.lite && (len(a.candidateTypes) != 1 || a.candidateTypes[0] != CandidateTypeHost) {
closeMDNSConn()
2022-03-10 10:44:48 +01:00
return nil, ErrLiteUsingNonHostCandidates
}
if config.Urls != nil && len(config.Urls) > 0 && !containsCandidateType(CandidateTypeServerReflexive, a.candidateTypes) && !containsCandidateType(CandidateTypeRelay, a.candidateTypes) {
closeMDNSConn()
2022-03-10 10:44:48 +01:00
return nil, ErrUselessUrlsProvided
}
if err = config.initExtIPMapping(a); err != nil {
closeMDNSConn()
2022-03-10 10:44:48 +01:00
return nil, err
}
go a.taskLoop()
a.startOnConnectionStateChangeRoutine()
2022-03-10 10:44:48 +01:00
// Restart is also used to initialize the agent for the first time
if err := a.Restart(config.LocalUfrag, config.LocalPwd); err != nil {
closeMDNSConn()
2022-03-10 10:44:48 +01:00
_ = a.Close()
return nil, err
}
return a, nil
}
// OnConnectionStateChange sets a handler that is fired when the connection state changes.
// The handler replaces any previously stored one. The returned error is always nil.
func (a *Agent) OnConnectionStateChange(f func(ConnectionState)) error {
a.onConnectionStateChangeHdlr.Store(f)
return nil
}
// OnSelectedCandidatePairChange sets a handler that is fired when the final candidate
// pair is selected. The handler receives the local and the remote candidate of
// the pair, in that order. The returned error is always nil.
func (a *Agent) OnSelectedCandidatePairChange(f func(Candidate, Candidate)) error {
a.onSelectedCandidatePairChangeHdlr.Store(f)
return nil
}
// OnCandidate sets a handler that is fired when a new candidate is gathered.
// When the gathering process completes, the handler is called one last time
// with a nil candidate. The returned error is always nil.
func (a *Agent) OnCandidate(f func(Candidate)) error {
a.onCandidateHdlr.Store(f)
return nil
}
// onSelectedCandidatePairChange invokes the registered selected-pair handler,
// if any, with the local and remote candidates of p.
func (a *Agent) onSelectedCandidatePairChange(p *CandidatePair) {
	handler, registered := a.onSelectedCandidatePairChangeHdlr.Load().(func(Candidate, Candidate))
	if !registered {
		return
	}
	handler(p.Local, p.Remote)
}
// onCandidate invokes the registered candidate handler, if any, with c.
func (a *Agent) onCandidate(c Candidate) {
	handler, registered := a.onCandidateHdlr.Load().(func(Candidate))
	if !registered {
		return
	}
	handler(c)
}
// onConnectionStateChange invokes the registered connection-state handler, if
// any, with s.
func (a *Agent) onConnectionStateChange(s ConnectionState) {
	handler, registered := a.onConnectionStateChangeHdlr.Load().(func(ConnectionState))
	if !registered {
		return
	}
	handler(s)
}
// startOnConnectionStateChangeRoutine starts the goroutines that forward
// events from the agent's notification channels to the user handlers.
//
// Selected-pair changes get their own goroutine; state changes and candidates
// share a second one. Both goroutines exit once their channels are closed.
func (a *Agent) startOnConnectionStateChangeRoutine() {
	go func() {
		// CandidatePair and ConnectionState are usually changed at once.
		// Blocking one by the other one causes deadlock.
		for pair := range a.chanCandidatePair {
			a.onSelectedCandidatePairChange(pair)
		}
	}()

	go func() {
		for {
			select {
			case state, open := <-a.chanState:
				if open {
					go a.onConnectionStateChange(state)
					continue
				}
				// chanState is closed: deliver what is left on chanCandidate, then stop.
				for candidate := range a.chanCandidate {
					a.onCandidate(candidate)
				}
				return
			case candidate, open := <-a.chanCandidate:
				if open {
					a.onCandidate(candidate)
					continue
				}
				// chanCandidate is closed: deliver what is left on chanState, then stop.
				for state := range a.chanState {
					go a.onConnectionStateChange(state)
				}
				return
			}
		}
	}()
}
2022-03-10 10:44:48 +01:00
func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remotePwd string) error {
a.muHaveStarted.Lock()
defer a.muHaveStarted.Unlock()
select {
case <-a.startedCh:
return ErrMultipleStart
default:
}
if err := a.SetRemoteCredentials(remoteUfrag, remotePwd); err != nil {
2022-03-10 10:44:48 +01:00
return err
}
a.log.Debugf("Started agent: isControlling? %t, remoteUfrag: %q, remotePwd: %q", isControlling, remoteUfrag, remotePwd)
return a.run(a.context(), func(ctx context.Context, agent *Agent) {
agent.isControlling = isControlling
agent.remoteUfrag = remoteUfrag
agent.remotePwd = remotePwd
if isControlling {
a.selector = &controllingSelector{agent: a, log: a.log}
} else {
a.selector = &controlledSelector{agent: a, log: a.log}
}
if a.lite {
a.selector = &liteSelector{pairCandidateSelector: a.selector}
}
a.selector.Start()
a.startedFn()
agent.updateConnectionState(ConnectionStateChecking)
a.requestConnectivityCheck()
go a.connectivityChecks()
2022-03-10 10:44:48 +01:00
})
}
func (a *Agent) connectivityChecks() {
lastConnectionState := ConnectionState(0)
checkingDuration := time.Time{}
contact := func() {
if err := a.run(a.context(), func(ctx context.Context, a *Agent) {
defer func() {
lastConnectionState = a.connectionState
}()
switch a.connectionState {
case ConnectionStateFailed:
// The connection is currently failed so don't send any checks
// In the future it may be restarted though
return
case ConnectionStateChecking:
// We have just entered checking for the first time so update our checking timer
if lastConnectionState != a.connectionState {
checkingDuration = time.Now()
}
// We have been in checking longer then Disconnect+Failed timeout, set the connection to Failed
if time.Since(checkingDuration) > a.disconnectedTimeout+a.failedTimeout {
a.updateConnectionState(ConnectionStateFailed)
return
}
}
a.selector.ContactCandidates()
}); err != nil {
a.log.Warnf("taskLoop failed: %v", err)
2022-03-10 10:44:48 +01:00
}
}
for {
interval := defaultKeepaliveInterval
updateInterval := func(x time.Duration) {
if x != 0 && (interval == 0 || interval > x) {
interval = x
}
}
switch lastConnectionState {
case ConnectionStateNew, ConnectionStateChecking: // While connecting, check candidates more frequently
updateInterval(a.checkInterval)
case ConnectionStateConnected, ConnectionStateDisconnected:
updateInterval(a.keepaliveInterval)
default:
}
// Ensure we run our task loop as quickly as the minimum of our various configured timeouts
updateInterval(a.disconnectedTimeout)
updateInterval(a.failedTimeout)
t := time.NewTimer(interval)
select {
case <-a.forceCandidateContact:
t.Stop()
contact()
case <-t.C:
contact()
case <-a.done:
t.Stop()
return
}
}
}
// updateConnectionState moves the agent to newState and schedules a
// notification on chanState. It does nothing when the state is unchanged.
func (a *Agent) updateConnectionState(newState ConnectionState) {
	if a.connectionState == newState {
		return
	}

	// Connection has gone to failed, release all gathered candidates
	if newState == ConnectionStateFailed {
		a.deleteAllCandidates()
	}

	a.log.Infof("Setting new connection state: %s", newState)
	a.connectionState = newState

	// Call handler after finishing current task since we may be holding the agent lock
	// and the handler may also require it
	notify := func(context.Context) {
		a.chanState <- newState
	}
	a.afterRun(notify)
}
func (a *Agent) setSelectedPair(p *CandidatePair) {
a.log.Tracef("Set selected candidate pair: %s", p)
2022-03-10 10:44:48 +01:00
if p == nil {
var nilPair *CandidatePair
a.selectedPair.Store(nilPair)
return
}
p.nominated = true
a.selectedPair.Store(p)
a.updateConnectionState(ConnectionStateConnected)
// Notify when the selected pair changes
if p != nil {
a.afterRun(func(ctx context.Context) {
select {
case a.chanCandidatePair <- p:
case <-ctx.Done():
}
})
}
2022-03-10 10:44:48 +01:00
// Signal connected
a.onConnectedOnce.Do(func() { close(a.onConnected) })
}
func (a *Agent) pingAllCandidates() {
a.log.Trace("pinging all candidates")
if len(a.checklist) == 0 {
a.log.Warn("pingAllCandidates called with no candidate pairs. Connection is not possible yet.")
2022-03-10 10:44:48 +01:00
}
for _, p := range a.checklist {
if p.state == CandidatePairStateWaiting {
p.state = CandidatePairStateInProgress
} else if p.state != CandidatePairStateInProgress {
continue
}
if p.bindingRequestCount > a.maxBindingRequests {
a.log.Tracef("max requests reached for pair %s, marking it as failed\n", p)
2022-03-10 10:44:48 +01:00
p.state = CandidatePairStateFailed
} else {
a.selector.PingCandidate(p.Local, p.Remote)
p.bindingRequestCount++
}
}
}
// getBestAvailableCandidatePair returns the highest-priority checklist pair
// that has not failed, or nil when there is none. On equal priority the
// earlier pair wins.
func (a *Agent) getBestAvailableCandidatePair() *CandidatePair {
	var top *CandidatePair
	for _, pair := range a.checklist {
		if pair.state == CandidatePairStateFailed {
			continue
		}
		if top == nil || top.priority() < pair.priority() {
			top = pair
		}
	}
	return top
}
// getBestValidCandidatePair returns the highest-priority checklist pair in
// the succeeded state, or nil when there is none. On equal priority the
// earlier pair wins.
func (a *Agent) getBestValidCandidatePair() *CandidatePair {
	var top *CandidatePair
	for _, pair := range a.checklist {
		if pair.state != CandidatePairStateSucceeded {
			continue
		}
		if top == nil || top.priority() < pair.priority() {
			top = pair
		}
	}
	return top
}
// addPair creates a candidate pair for local and remote, appends it to the
// checklist and returns it.
func (a *Agent) addPair(local, remote Candidate) *CandidatePair {
	pair := newCandidatePair(local, remote, a.isControlling)
	a.checklist = append(a.checklist, pair)
	return pair
}
// findPair returns the first checklist pair whose candidates equal local and
// remote, or nil when no such pair exists.
func (a *Agent) findPair(local, remote Candidate) *CandidatePair {
	for i := range a.checklist {
		pair := a.checklist[i]
		if !pair.Local.Equal(local) {
			continue
		}
		if pair.Remote.Equal(remote) {
			return pair
		}
	}
	return nil
}
// validateSelectedPair checks if the selected pair is (still) valid and
// updates the connection state according to how long ago the remote candidate
// was last heard from. It returns false only when no pair is selected.
// Note: the caller should hold the agent lock.
func (a *Agent) validateSelectedPair() bool {
	pair := a.getSelectedPair()
	if pair == nil {
		return false
	}

	silence := time.Since(pair.Remote.LastReceived())

	// Only allow transitions to failed if a.failedTimeout is non-zero
	var failAfter time.Duration
	if a.failedTimeout != 0 {
		failAfter = a.failedTimeout + a.disconnectedTimeout
	}

	next := ConnectionStateConnected
	if failAfter != 0 && silence > failAfter {
		next = ConnectionStateFailed
	} else if a.disconnectedTimeout != 0 && silence > a.disconnectedTimeout {
		next = ConnectionStateDisconnected
	}
	a.updateConnectionState(next)

	return true
}
// checkKeepalive sends STUN Binding Indications to the selected pair
// if no packet has been sent on that pair in the last keepaliveInterval
// Note: the caller should hold the agent lock.
func (a *Agent) checkKeepalive() {
selectedPair := a.getSelectedPair()
if selectedPair == nil {
return
}
if (a.keepaliveInterval != 0) &&
((time.Since(selectedPair.Local.LastSent()) > a.keepaliveInterval) ||
(time.Since(selectedPair.Remote.LastReceived()) > a.keepaliveInterval)) {
// we use binding request instead of indication to support refresh consent schemas
2022-03-10 10:44:48 +01:00
// see https://tools.ietf.org/html/rfc7675
a.selector.PingCandidate(selectedPair.Local, selectedPair.Remote)
}
}
// AddRemoteCandidate adds a new remote candidate
func (a *Agent) AddRemoteCandidate(c Candidate) error {
if c == nil {
return nil
}
// cannot check for network yet because it might not be applied
// when mDNS hostame is used.
2022-03-10 10:44:48 +01:00
if c.TCPType() == TCPTypeActive {
// TCP Candidates with tcptype active will probe server passive ones, so
2022-03-10 10:44:48 +01:00
// no need to do anything with them.
a.log.Infof("Ignoring remote candidate with tcpType active: %s", c)
return nil
}
// If we have a mDNS Candidate lets fully resolve it before adding it locally
if c.Type() == CandidateTypeHost && strings.HasSuffix(c.Address(), ".local") {
if a.mDNSMode == MulticastDNSModeDisabled {
a.log.Warnf("remote mDNS candidate added, but mDNS is disabled: (%s)", c.Address())
2022-03-10 10:44:48 +01:00
return nil
}
hostCandidate, ok := c.(*CandidateHost)
if !ok {
return ErrAddressParseFailed
}
go a.resolveAndAddMulticastCandidate(hostCandidate)
return nil
}
go func() {
if err := a.run(a.context(), func(ctx context.Context, agent *Agent) {
agent.addRemoteCandidate(c)
}); err != nil {
a.log.Warnf("Failed to add remote candidate %s: %v", c.Address(), err)
return
}
}()
return nil
}
func (a *Agent) resolveAndAddMulticastCandidate(c *CandidateHost) {
if a.mDNSConn == nil {
return
}
_, src, err := a.mDNSConn.Query(c.context(), c.Address())
if err != nil {
a.log.Warnf("Failed to discover mDNS candidate %s: %v", c.Address(), err)
return
}
ip, _, _, _ := parseAddr(src) //nolint:dogsled
if ip == nil {
2022-03-10 10:44:48 +01:00
a.log.Warnf("Failed to discover mDNS candidate %s: failed to parse IP", c.Address())
return
}
if err = c.setIP(ip); err != nil {
a.log.Warnf("Failed to discover mDNS candidate %s: %v", c.Address(), err)
return
}
if err = a.run(a.context(), func(ctx context.Context, agent *Agent) {
agent.addRemoteCandidate(c)
}); err != nil {
a.log.Warnf("Failed to add mDNS candidate %s: %v", c.Address(), err)
return
}
}
// requestConnectivityCheck asks connectivityChecks to contact candidates right
// away instead of waiting for its timer. The send never blocks: if
// forceCandidateContact cannot accept the value, the request is dropped.
func (a *Agent) requestConnectivityCheck() {
select {
case a.forceCandidateContact <- true:
default:
}
}
// addRemoteCandidate assumes you are holding the lock (must be executed using a.run).
//
// It records c as a remote candidate unless an equal one is already known,
// pairs it with every local candidate of the same network type and requests
// an immediate connectivity check.
func (a *Agent) addRemoteCandidate(c Candidate) {
	networkType := c.NetworkType()

	known := a.remoteCandidates[networkType]
	for _, existing := range known {
		if existing.Equal(c) {
			return
		}
	}
	a.remoteCandidates[networkType] = append(known, c)

	// Ranging over a missing map entry yields no iterations.
	for _, local := range a.localCandidates[networkType] {
		a.addPair(local, c)
	}

	a.requestConnectivityCheck()
}
// addCandidate registers c as a local candidate from within a task.
//
// A duplicate of an already known candidate is closed and ignored. Otherwise
// c is started on candidateConn, paired with every remote candidate of the
// same network type, a connectivity check is requested and c is published on
// chanCandidate. The returned error comes from a.run.
func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net.PacketConn) error {
	register := func(_ context.Context, agent *Agent) {
		networkType := c.NetworkType()

		known := agent.localCandidates[networkType]
		for _, existing := range known {
			if !existing.Equal(c) {
				continue
			}
			agent.log.Debugf("Ignore duplicate candidate: %s", c.String())
			if closeErr := c.close(); closeErr != nil {
				agent.log.Warnf("Failed to close duplicate candidate: %v", closeErr)
			}
			return
		}

		c.start(agent, candidateConn, agent.startedCh)

		agent.localCandidates[networkType] = append(known, c)

		// Ranging over a missing map entry yields no iterations.
		for _, remote := range agent.remoteCandidates[networkType] {
			agent.addPair(c, remote)
		}

		agent.requestConnectivityCheck()

		agent.chanCandidate <- c
	}

	return a.run(ctx, register)
}
// GetLocalCandidates returns the local candidates of all network types as a
// single slice. The error comes from a.run, in which case the slice is nil.
func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
	var collected []Candidate

	collect := func(_ context.Context, agent *Agent) {
		for _, perNetwork := range agent.localCandidates {
			collected = append(collected, perNetwork...)
		}
	}
	if err := a.run(a.context(), collect); err != nil {
		return nil, err
	}

	return collected, nil
}
// GetLocalUserCredentials returns the local user credentials.
//
// The values are read inside a task. a.run only returns nil after the task
// has completed, so no additional signalling is needed before reading the
// results. On error both strings are empty.
func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
	err = a.run(a.context(), func(_ context.Context, agent *Agent) {
		frag = agent.localUfrag
		pwd = agent.localPwd
	})
	return frag, pwd, err
}
// GetRemoteUserCredentials returns the remote user credentials.
//
// The values are read inside a task. a.run only returns nil after the task
// has completed, so no additional signalling is needed before reading the
// results. On error both strings are empty.
func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error) {
	err = a.run(a.context(), func(_ context.Context, agent *Agent) {
		frag = agent.remoteUfrag
		pwd = agent.remotePwd
	})
	return frag, pwd, err
}
func (a *Agent) removeUfragFromMux() {
a.tcpMux.RemoveConnByUfrag(a.localUfrag)
2022-03-10 10:44:48 +01:00
if a.udpMux != nil {
a.udpMux.RemoveConnByUfrag(a.localUfrag)
}
}
// Close cleans up the Agent. It blocks until the task loop has shut down.
// Calling Close on an already closed agent returns the agent's stored error
// (ErrClosed unless another error was recorded).
func (a *Agent) Close() error {
if err := a.ok(); err != nil {
return err
}
// Queue the gather cancellation so it runs with the after-run callbacks
// during task loop shutdown.
a.afterRun(func(context.Context) {
a.gatherCandidateCancel()
})
a.err.Store(ErrClosed)
a.removeUfragFromMux()
// Closing done makes taskLoop return; wait until its cleanup has finished.
close(a.done)
<-a.taskLoopDone
return nil
}
// Remove all candidates. This closes any listening sockets
// and removes both the local and remote candidate lists.
//
// This is used for restarts, failures and on close
func (a *Agent) deleteAllCandidates() {
	// purge closes every candidate in the map and deletes its entries.
	// Close errors are logged and otherwise ignored.
	purge := func(byNetwork map[NetworkType][]Candidate) {
		for networkType, candidates := range byNetwork {
			for _, candidate := range candidates {
				if err := candidate.close(); err != nil {
					a.log.Warnf("Failed to close candidate %s: %v", candidate, err)
				}
			}
			delete(byNetwork, networkType)
		}
	}

	purge(a.localCandidates)
	purge(a.remoteCandidates)
}
func (a *Agent) findRemoteCandidate(networkType NetworkType, addr net.Addr) Candidate {
ip, port, _, ok := parseAddr(addr)
if !ok {
a.log.Warnf("Error parsing addr: %s", addr)
2022-03-10 10:44:48 +01:00
return nil
}
set := a.remoteCandidates[networkType]
for _, c := range set {
if c.Address() == ip.String() && c.Port() == port {
return c
}
}
return nil
}
func (a *Agent) sendBindingRequest(m *stun.Message, local, remote Candidate) {
a.log.Tracef("ping STUN from %s to %s\n", local.String(), remote.String())
2022-03-10 10:44:48 +01:00
a.invalidatePendingBindingRequests(time.Now())
a.pendingBindingRequests = append(a.pendingBindingRequests, bindingRequest{
timestamp: time.Now(),
transactionID: m.TransactionID,
destination: remote.addr(),
isUseCandidate: m.Contains(stun.AttrUseCandidate),
})
a.sendSTUN(m, local, remote)
}
func (a *Agent) sendBindingSuccess(m *stun.Message, local, remote Candidate) {
base := remote
ip, port, _, ok := parseAddr(base.addr())
if !ok {
a.log.Warnf("Error parsing addr: %s", base.addr())
2022-03-10 10:44:48 +01:00
return
}
if out, err := stun.Build(m, stun.BindingSuccess,
&stun.XORMappedAddress{
IP: ip,
Port: port,
},
stun.NewShortTermIntegrity(a.localPwd),
stun.Fingerprint,
); err != nil {
a.log.Warnf("Failed to handle inbound ICE from: %s to: %s error: %s", local, remote, err)
} else {
a.sendSTUN(out, local, remote)
}
}
/* Removes pending binding requests that are over maxBindingRequestTimeout old
Let HTO be the transaction timeout, which SHOULD be 2*RTT if
RTT is known or 500 ms otherwise.
https://tools.ietf.org/html/rfc8445#appendix-B.1
*/
2022-03-10 10:44:48 +01:00
func (a *Agent) invalidatePendingBindingRequests(filterTime time.Time) {
initialSize := len(a.pendingBindingRequests)
temp := a.pendingBindingRequests[:0]
for _, bindingRequest := range a.pendingBindingRequests {
if filterTime.Sub(bindingRequest.timestamp) < maxBindingRequestTimeout {
temp = append(temp, bindingRequest)
}
}
a.pendingBindingRequests = temp
if bindRequestsRemoved := initialSize - len(a.pendingBindingRequests); bindRequestsRemoved > 0 {
a.log.Tracef("Discarded %d binding requests because they expired", bindRequestsRemoved)
}
}
// handleInboundBindingSuccess checks whether id belongs to one of our pending
// binding requests. When it does, the request is removed from the pending
// list and a copy of it is returned together with true; otherwise it returns
// false and nil. Expired requests are discarded before the lookup.
func (a *Agent) handleInboundBindingSuccess(id [stun.TransactionIDSize]byte) (bool, *bindingRequest) {
	a.invalidatePendingBindingRequests(time.Now())

	for i, pending := range a.pendingBindingRequests {
		if pending.transactionID != id {
			continue
		}
		// pending is a copy, so it stays intact when the slice is compacted.
		a.pendingBindingRequests = append(a.pendingBindingRequests[:i], a.pendingBindingRequests[i+1:]...)
		return true, &pending
	}

	return false, nil
}
// handleInbound processes STUN traffic from a remote candidate
func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr) { //nolint:gocognit
var err error
if m == nil || local == nil {
return
}
if m.Type.Method != stun.MethodBinding ||
!(m.Type.Class == stun.ClassSuccessResponse ||
m.Type.Class == stun.ClassRequest ||
m.Type.Class == stun.ClassIndication) {
a.log.Tracef("unhandled STUN from %s to %s class(%s) method(%s)", remote, local, m.Type.Class, m.Type.Method)
return
}
if a.isControlling {
if m.Contains(stun.AttrICEControlling) {
a.log.Debug("inbound isControlling && a.isControlling == true")
2022-03-10 10:44:48 +01:00
return
} else if m.Contains(stun.AttrUseCandidate) {
a.log.Debug("useCandidate && a.isControlling == true")
2022-03-10 10:44:48 +01:00
return
}
} else {
if m.Contains(stun.AttrICEControlled) {
a.log.Debug("inbound isControlled && a.isControlling == false")
2022-03-10 10:44:48 +01:00
return
}
}
remoteCandidate := a.findRemoteCandidate(local.NetworkType(), remote)
if m.Type.Class == stun.ClassSuccessResponse {
if err = assertInboundMessageIntegrity(m, []byte(a.remotePwd)); err != nil {
a.log.Warnf("discard message from (%s), %v", remote, err)
2022-03-10 10:44:48 +01:00
return
}
if remoteCandidate == nil {
a.log.Warnf("discard success message from (%s), no such remote", remote)
2022-03-10 10:44:48 +01:00
return
}
a.selector.HandleSuccessResponse(m, local, remoteCandidate, remote)
} else if m.Type.Class == stun.ClassRequest {
if err = assertInboundUsername(m, a.localUfrag+":"+a.remoteUfrag); err != nil {
a.log.Warnf("discard message from (%s), %v", remote, err)
2022-03-10 10:44:48 +01:00
return
} else if err = assertInboundMessageIntegrity(m, []byte(a.localPwd)); err != nil {
a.log.Warnf("discard message from (%s), %v", remote, err)
2022-03-10 10:44:48 +01:00
return
}
if remoteCandidate == nil {
ip, port, networkType, ok := parseAddr(remote)
if !ok {
a.log.Errorf("Failed to create parse remote net.Addr when creating remote prflx candidate")
return
}
prflxCandidateConfig := CandidatePeerReflexiveConfig{
Network: networkType.String(),
Address: ip.String(),
Port: port,
Component: local.Component(),
RelAddr: "",
RelPort: 0,
}
prflxCandidate, err := NewCandidatePeerReflexive(&prflxCandidateConfig)
if err != nil {
a.log.Errorf("Failed to create new remote prflx candidate (%s)", err)
return
}
remoteCandidate = prflxCandidate
a.log.Debugf("adding a new peer-reflexive candidate: %s ", remote)
2022-03-10 10:44:48 +01:00
a.addRemoteCandidate(remoteCandidate)
}
a.log.Tracef("inbound STUN (Request) from %s to %s", remote.String(), local.String())
a.selector.HandleBindingRequest(m, local, remoteCandidate)
}
if remoteCandidate != nil {
remoteCandidate.seen(false)
}
}
// validateNonSTUNTraffic processes non STUN traffic from a remote candidate,
// and returns true if it is an actual remote candidate
func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) bool {
var isValidCandidate uint64
2022-03-10 10:44:48 +01:00
if err := a.run(local.context(), func(ctx context.Context, agent *Agent) {
remoteCandidate := a.findRemoteCandidate(local.NetworkType(), remote)
2022-03-10 10:44:48 +01:00
if remoteCandidate != nil {
remoteCandidate.seen(false)
atomic.AddUint64(&isValidCandidate, 1)
2022-03-10 10:44:48 +01:00
}
}); err != nil {
a.log.Warnf("failed to validate remote candidate: %v", err)
2022-03-10 10:44:48 +01:00
}
return atomic.LoadUint64(&isValidCandidate) == 1
2022-03-10 10:44:48 +01:00
}
// GetSelectedCandidatePair returns the selected pair or nil if there is none
func (a *Agent) GetSelectedCandidatePair() (*CandidatePair, error) {
selectedPair := a.getSelectedPair()
if selectedPair == nil {
return nil, nil
2022-03-10 10:44:48 +01:00
}
local, err := selectedPair.Local.copy()
if err != nil {
return nil, err
}
remote, err := selectedPair.Remote.copy()
if err != nil {
return nil, err
}
return &CandidatePair{Local: local, Remote: remote}, nil
}
func (a *Agent) getSelectedPair() *CandidatePair {
selectedPair := a.selectedPair.Load()
if selectedPair == nil {
return nil
2022-03-10 10:44:48 +01:00
}
return selectedPair.(*CandidatePair)
2022-03-10 10:44:48 +01:00
}
func (a *Agent) closeMulticastConn() {
if a.mDNSConn != nil {
if err := a.mDNSConn.Close(); err != nil {
a.log.Warnf("failed to close mDNS Conn: %v", err)
2022-03-10 10:44:48 +01:00
}
}
}
// SetRemoteCredentials sets the credentials of the remote agent.
// It returns ErrRemoteUfragEmpty or ErrRemotePwdEmpty when the respective
// value is empty, and otherwise the result of running the update task.
func (a *Agent) SetRemoteCredentials(remoteUfrag, remotePwd string) error {
	if remoteUfrag == "" {
		return ErrRemoteUfragEmpty
	}
	if remotePwd == "" {
		return ErrRemotePwdEmpty
	}

	store := func(_ context.Context, agent *Agent) {
		agent.remoteUfrag = remoteUfrag
		agent.remotePwd = remotePwd
	}
	return a.run(a.context(), store)
}
// Restart restarts the ICE Agent with the provided ufrag/pwd
// If no ufrag/pwd is provided the Agent will generate one itself
//
// Restart must only be called when GatheringState is GatheringStateComplete
// a user must then call GatherCandidates explicitly to start generating new ones
2022-03-10 10:44:48 +01:00
func (a *Agent) Restart(ufrag, pwd string) error {
if ufrag == "" {
var err error
ufrag, err = generateUFrag()
if err != nil {
return err
}
}
if pwd == "" {
var err error
pwd, err = generatePwd()
if err != nil {
return err
}
}
if len([]rune(ufrag))*8 < 24 {
return ErrLocalUfragInsufficientBits
}
if len([]rune(pwd))*8 < 128 {
return ErrLocalPwdInsufficientBits
}
var err error
if runErr := a.run(a.context(), func(ctx context.Context, agent *Agent) {
if agent.gatheringState == GatheringStateGathering {
err = ErrRestartWhenGathering
return
2022-03-10 10:44:48 +01:00
}
// Clear all agent needed to take back to fresh state
a.removeUfragFromMux()
agent.localUfrag = ufrag
agent.localPwd = pwd
agent.remoteUfrag = ""
agent.remotePwd = ""
a.gatheringState = GatheringStateNew
a.checklist = make([]*CandidatePair, 0)
a.pendingBindingRequests = make([]bindingRequest, 0)
a.setSelectedPair(nil)
a.deleteAllCandidates()
if a.selector != nil {
a.selector.Start()
}
// Restart is used by NewAgent. Accept/Connect should be used to move to checking
// for new Agents
if a.connectionState != ConnectionStateNew {
a.updateConnectionState(ConnectionStateChecking)
}
}); runErr != nil {
return runErr
}
return err
}
// setGatheringState records newState as the agent's gathering state from
// within a task. On a transition to GatheringStateComplete a nil candidate is
// first published on chanCandidate to signal the end of gathering.
//
// a.run only returns nil after the task has completed, so no additional
// signalling is needed. The returned error comes from a.run.
func (a *Agent) setGatheringState(newState GatheringState) error {
	return a.run(a.context(), func(_ context.Context, agent *Agent) {
		if agent.gatheringState != newState && newState == GatheringStateComplete {
			agent.chanCandidate <- nil
		}
		agent.gatheringState = newState
	})
}