// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package ice import ( "io" "net" "sync" "time" "github.com/pion/logging" ) type udpMuxedConnState int const ( udpMuxedConnOpen udpMuxedConnState = iota udpMuxedConnWaiting udpMuxedConnClosed ) type udpMuxedConnAddr struct { ip [16]byte port uint16 } func newUDPMuxedConnAddr(addr *net.UDPAddr) (a udpMuxedConnAddr) { copy(a.ip[:], addr.IP.To16()) a.port = uint16(addr.Port) return a } type udpMuxedConnParams struct { Mux *UDPMuxDefault AddrPool *sync.Pool Key string LocalAddr net.Addr Logger logging.LeveledLogger } // udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag type udpMuxedConn struct { params *udpMuxedConnParams // Remote addresses that we have sent to on this conn addresses []udpMuxedConnAddr // FIFO queue holding incoming packets bufHead, bufTail *bufferHolder notify chan struct{} closedChan chan struct{} state udpMuxedConnState mu sync.Mutex } func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn { return &udpMuxedConn{ params: params, notify: make(chan struct{}, 1), closedChan: make(chan struct{}), } } func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) { for { c.mu.Lock() if c.bufTail != nil { pkt := c.bufTail c.bufTail = pkt.next if pkt == c.bufHead { c.bufHead = nil } c.mu.Unlock() if len(b) < len(pkt.buf) { err = io.ErrShortBuffer } else { n = copy(b, pkt.buf) rAddr = pkt.addr } pkt.reset() c.params.AddrPool.Put(pkt) return } if c.state == udpMuxedConnClosed { c.mu.Unlock() return 0, nil, io.EOF } c.state = udpMuxedConnWaiting c.mu.Unlock() select { case <-c.notify: case <-c.closedChan: return 0, nil, io.EOF } } } func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { if c.isClosed() { return 0, io.ErrClosedPipe } // Each time we write to a new address, we'll register it with the mux addr := newUDPMuxedConnAddr(rAddr.(*net.UDPAddr)) if !c.containsAddress(addr) { c.addAddress(addr) } return c.params.Mux.writeTo(buf, rAddr) } func (c *udpMuxedConn) LocalAddr() net.Addr { return c.params.LocalAddr } func (c *udpMuxedConn) SetDeadline(time.Time) error { return nil } func (c *udpMuxedConn) SetReadDeadline(time.Time) error { return nil } func (c *udpMuxedConn) SetWriteDeadline(time.Time) error { return nil } func (c *udpMuxedConn) CloseChannel() <-chan struct{} { return c.closedChan } func (c *udpMuxedConn) Close() error { c.mu.Lock() defer c.mu.Unlock() if c.state != udpMuxedConnClosed { for pkt := c.bufTail; pkt != nil; { next := pkt.next pkt.reset() c.params.AddrPool.Put(pkt) pkt = next } c.bufHead = nil c.bufTail = nil c.state = udpMuxedConnClosed close(c.closedChan) } return nil } func (c *udpMuxedConn) isClosed() bool { c.mu.Lock() defer c.mu.Unlock() return c.state == udpMuxedConnClosed } func (c *udpMuxedConn) getAddresses() []udpMuxedConnAddr { c.mu.Lock() defer c.mu.Unlock() addresses := make([]udpMuxedConnAddr, len(c.addresses)) copy(addresses, c.addresses) return addresses } func (c *udpMuxedConn) addAddress(addr udpMuxedConnAddr) { c.mu.Lock() c.addresses = append(c.addresses, addr) c.mu.Unlock() // Map it on mux c.params.Mux.registerConnForAddress(c, addr) } func (c *udpMuxedConn) removeAddress(addr udpMuxedConnAddr) { c.mu.Lock() defer c.mu.Unlock() newAddresses := make([]udpMuxedConnAddr, 0, len(c.addresses)) for _, a := range c.addresses { if a != addr { newAddresses = append(newAddresses, a) } } c.addresses = newAddresses } func (c *udpMuxedConn) containsAddress(addr udpMuxedConnAddr) bool { c.mu.Lock() defer c.mu.Unlock() for _, a := range c.addresses { if addr == a { return true } } return false } func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error { pkt := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert if cap(pkt.buf) < len(data) { c.params.AddrPool.Put(pkt) return io.ErrShortBuffer } pkt.buf = append(pkt.buf[:0], data...) pkt.addr = addr c.mu.Lock() if c.state == udpMuxedConnClosed { c.mu.Unlock() pkt.reset() c.params.AddrPool.Put(pkt) return io.ErrClosedPipe } if c.bufHead != nil { c.bufHead.next = pkt } c.bufHead = pkt if c.bufTail == nil { c.bufTail = pkt } state := c.state c.state = udpMuxedConnOpen c.mu.Unlock() if state == udpMuxedConnWaiting { select { case c.notify <- struct{}{}: default: } } return nil }