package ice import ( "io" "net" "os" "strings" "sync" "github.com/pion/logging" "github.com/pion/stun" ) // UDPMux allows multiple connections to go over a single UDP port type UDPMux interface { io.Closer GetConn(ufrag string) (net.PacketConn, error) RemoveConnByUfrag(ufrag string) } // UDPMuxDefault is an implementation of the interface type UDPMuxDefault struct { params UDPMuxParams closedChan chan struct{} closeOnce sync.Once // conns is a map of all udpMuxedConn indexed by ufrag|network|candidateType conns map[string]*udpMuxedConn addressMapMu sync.RWMutex addressMap map[string]*udpMuxedConn // buffer pool to recycle buffers for net.UDPAddr encodes/decodes pool *sync.Pool mu sync.Mutex } const maxAddrSize = 512 // UDPMuxParams are parameters for UDPMux. type UDPMuxParams struct { Logger logging.LeveledLogger UDPConn net.PacketConn } // NewUDPMuxDefault creates an implementation of UDPMux func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { if params.Logger == nil { params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") } m := &UDPMuxDefault{ addressMap: map[string]*udpMuxedConn{}, params: params, conns: make(map[string]*udpMuxedConn), closedChan: make(chan struct{}, 1), pool: &sync.Pool{ New: func() interface{} { // big enough buffer to fit both packet and address return newBufferHolder(receiveMTU + maxAddrSize) }, }, } go m.connWorker() return m } // LocalAddr returns the listening address of this UDPMuxDefault func (m *UDPMuxDefault) LocalAddr() net.Addr { return m.params.UDPConn.LocalAddr() } // GetConn returns a PacketConn given the connection's ufrag and network // creates the connection if an existing one can't be found func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) { m.mu.Lock() defer m.mu.Unlock() if m.IsClosed() { return nil, io.ErrClosedPipe } if c, ok := m.conns[ufrag]; ok { return c, nil } c := m.createMuxedConn(ufrag) go func() { <-c.CloseChannel() m.removeConn(ufrag) }() m.conns[ufrag] = c return c, nil } // RemoveConnByUfrag stops and removes the muxed packet connection func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { m.mu.Lock() removedConns := make([]*udpMuxedConn, 0) for key := range m.conns { if key != ufrag { continue } c := m.conns[key] delete(m.conns, key) if c != nil { removedConns = append(removedConns, c) } } // keep lock section small to avoid deadlock with conn lock m.mu.Unlock() m.addressMapMu.Lock() defer m.addressMapMu.Unlock() for _, c := range removedConns { addresses := c.getAddresses() for _, addr := range addresses { delete(m.addressMap, addr) } } } // IsClosed returns true if the mux had been closed func (m *UDPMuxDefault) IsClosed() bool { select { case <-m.closedChan: return true default: return false } } // Close the mux, no further connections could be created func (m *UDPMuxDefault) Close() error { var err error m.closeOnce.Do(func() { m.mu.Lock() defer m.mu.Unlock() for _, c := range m.conns { _ = c.Close() } m.conns = make(map[string]*udpMuxedConn) close(m.closedChan) }) return err } func (m *UDPMuxDefault) removeConn(key string) { m.mu.Lock() c := m.conns[key] delete(m.conns, key) // keep lock section small to avoid deadlock with conn lock m.mu.Unlock() if c == nil { return } m.addressMapMu.Lock() defer m.addressMapMu.Unlock() addresses := c.getAddresses() for _, addr := range addresses { delete(m.addressMap, addr) } } func (m *UDPMuxDefault) writeTo(buf []byte, raddr net.Addr) (n int, err error) { return m.params.UDPConn.WriteTo(buf, raddr) } func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) { if m.IsClosed() { return } m.addressMapMu.Lock() defer m.addressMapMu.Unlock() existing, ok := m.addressMap[addr] if ok { existing.removeAddress(addr) } m.addressMap[addr] = conn m.params.Logger.Debugf("Registered %s for %s", addr, conn.params.Key) } func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { c := newUDPMuxedConn(&udpMuxedConnParams{ Mux: m, Key: key, AddrPool: m.pool, LocalAddr: m.LocalAddr(), Logger: m.params.Logger, }) return c } func (m *UDPMuxDefault) connWorker() { logger := m.params.Logger defer func() { _ = m.Close() }() buf := make([]byte, receiveMTU) for { n, addr, err := m.params.UDPConn.ReadFrom(buf) if m.IsClosed() { return } else if err != nil { if os.IsTimeout(err) { continue } else if err != io.EOF { logger.Errorf("could not read udp packet: %v", err) } return } udpAddr, ok := addr.(*net.UDPAddr) if !ok { logger.Errorf("underlying PacketConn did not return a UDPAddr") return } // If we have already seen this address dispatch to the appropriate destination m.addressMapMu.Lock() destinationConn := m.addressMap[addr.String()] m.addressMapMu.Unlock() // If we haven't seen this address before but is a STUN packet lookup by ufrag if destinationConn == nil && stun.IsMessage(buf[:n]) { msg := &stun.Message{ Raw: append([]byte{}, buf[:n]...), } if err = msg.Decode(); err != nil { m.params.Logger.Warnf("Failed to handle decode ICE from %s: %v\n", addr.String(), err) continue } attr, stunAttrErr := msg.Get(stun.AttrUsername) if stunAttrErr != nil { m.params.Logger.Warnf("No Username attribute in STUN message from %s\n", addr.String()) continue } ufrag := strings.Split(string(attr), ":")[0] m.mu.Lock() destinationConn = m.conns[ufrag] m.mu.Unlock() } if destinationConn == nil { m.params.Logger.Tracef("dropping packet from %s, addr: %s", udpAddr.String(), addr.String()) continue } if err = destinationConn.writePacket(buf[:n], udpAddr); err != nil { m.params.Logger.Errorf("could not write packet: %v", err) } } } type bufferHolder struct { buffer []byte } func newBufferHolder(size int) *bufferHolder { return &bufferHolder{ buffer: make([]byte, size), } }