407 lines
9.6 KiB
Go
Raw Normal View History

2024-05-15 19:15:00 -04:00
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
2022-03-10 10:44:48 +01:00
// Package udp provides a connection-oriented listener over a UDP PacketConn
package udp
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"time"
2024-06-05 16:10:03 -04:00
"github.com/pion/logging"
2024-05-15 19:15:00 -04:00
"github.com/pion/transport/v2/deadline"
"github.com/pion/transport/v2/packetio"
2024-06-05 16:10:03 -04:00
"golang.org/x/net/ipv4"
2022-03-10 10:44:48 +01:00
)
// receiveMTU is the size, in bytes, of each buffer the listener allocates
// to receive a single datagram.
const receiveMTU = 8192

// sendMTU is not referenced in this part of the file.
// NOTE(review): presumably consumed by the batch writer — confirm.
const sendMTU = 1500

// defaultListenBacklog is the accept-queue length used when
// ListenConfig.Backlog is zero; it matches the Linux default.
const defaultListenBacklog = 128
// Sentinel errors reported by this package.
var (
	// ErrClosedListener is returned by Accept after Close, and internally
	// for packets from new remotes once the listener stops accepting.
	ErrClosedListener = errors.New("udp: listener closed")

	// ErrListenQueueExceeded is used internally when a packet from a new
	// remote arrives while the accept queue is full; that packet is dropped.
	ErrListenQueueExceeded = errors.New("udp: listen queue exceeded")

	// ErrInvalidBatchConfig is returned by ListenConfig.Listen when batching
	// is enabled without a positive WriteBatchSize and WriteBatchInterval.
	ErrInvalidBatchConfig = errors.New("udp: invalid batch config")
)
// listener augments a connection-oriented Listener over a UDP PacketConn
type listener struct {
2024-06-05 16:10:03 -04:00
pConn net.PacketConn
2022-03-10 10:44:48 +01:00
2024-06-05 16:10:03 -04:00
readBatchSize int
accepting atomic.Value // bool
acceptCh chan *Conn
doneCh chan struct{}
doneOnce sync.Once
acceptFilter func([]byte) bool
2022-03-10 10:44:48 +01:00
connLock sync.Mutex
conns map[string]*Conn
2024-05-15 19:15:00 -04:00
connWG *sync.WaitGroup
2022-03-10 10:44:48 +01:00
readWG sync.WaitGroup
errClose atomic.Value // error
2024-05-15 19:15:00 -04:00
readDoneCh chan struct{}
errRead atomic.Value // error
2024-06-05 16:10:03 -04:00
logger logging.LeveledLogger
2022-03-10 10:44:48 +01:00
}
// Accept waits for and returns the next connection to the listener.
func (l *listener) Accept() (net.Conn, error) {
select {
case c := <-l.acceptCh:
l.connWG.Add(1)
return c, nil
2024-05-15 19:15:00 -04:00
case <-l.readDoneCh:
err, _ := l.errRead.Load().(error)
return nil, err
2022-03-10 10:44:48 +01:00
case <-l.doneCh:
return nil, ErrClosedListener
}
}
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
func (l *listener) Close() error {
var err error
l.doneOnce.Do(func() {
l.accepting.Store(false)
close(l.doneCh)
l.connLock.Lock()
// Close unaccepted connections
2024-05-15 19:15:00 -04:00
lclose:
2022-03-10 10:44:48 +01:00
for {
select {
case c := <-l.acceptCh:
close(c.doneCh)
delete(l.conns, c.rAddr.String())
default:
2024-05-15 19:15:00 -04:00
break lclose
2022-03-10 10:44:48 +01:00
}
}
nConns := len(l.conns)
l.connLock.Unlock()
l.connWG.Done()
if nConns == 0 {
// Wait if this is the final connection
l.readWG.Wait()
if errClose, ok := l.errClose.Load().(error); ok {
err = errClose
}
} else {
err = nil
}
})
return err
}
// Addr returns the local network address of the listener's shared socket.
func (l *listener) Addr() net.Addr {
	local := l.pConn.LocalAddr()
	return local
}
// BatchIOConfig indicates config to batch read/write packets,
// it will use ReadBatch/WriteBatch to improve throughput for UDP.
type BatchIOConfig struct {
	// Enable turns batched I/O on. When set, WriteBatchSize and
	// WriteBatchInterval must both be positive.
	Enable bool

	// ReadBatchSize indicates the maximum number of packets to be read in
	// one batch; a batch size less than 2 disables read batching.
	ReadBatchSize int

	// WriteBatchSize indicates the maximum number of packets to be written
	// in one batch.
	WriteBatchSize int

	// WriteBatchInterval indicates the maximum interval to wait before
	// writing packets in one batch. A small interval reduces latency and
	// jitter but increases the number of I/O operations.
	WriteBatchInterval time.Duration
}
// ListenConfig stores options for listening to an address.
type ListenConfig struct {
// Backlog defines the maximum length of the queue of pending
// connections. It is equivalent of the backlog argument of
// POSIX listen function.
// If a connection request arrives when the queue is full,
// the request will be silently discarded, unlike TCP.
// Set zero to use default value 128 which is same as Linux default.
Backlog int
// AcceptFilter determines whether the new conn should be made for
// the incoming packet. If not set, any packet creates new conn.
AcceptFilter func([]byte) bool
2024-06-05 16:10:03 -04:00
// ReadBufferSize sets the size of the operating system's
// receive buffer associated with the listener.
ReadBufferSize int
// WriteBufferSize sets the size of the operating system's
// send buffer associated with the connection.
WriteBufferSize int
Batch BatchIOConfig
LoggerFactory logging.LoggerFactory
2022-03-10 10:44:48 +01:00
}
// Listen creates a new listener based on the ListenConfig.
func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (net.Listener, error) {
if lc.Backlog == 0 {
lc.Backlog = defaultListenBacklog
}
2024-06-05 16:10:03 -04:00
if lc.Batch.Enable && (lc.Batch.WriteBatchSize <= 0 || lc.Batch.WriteBatchInterval <= 0) {
return nil, ErrInvalidBatchConfig
}
2022-03-10 10:44:48 +01:00
conn, err := net.ListenUDP(network, laddr)
if err != nil {
return nil, err
}
2024-06-05 16:10:03 -04:00
if lc.ReadBufferSize > 0 {
_ = conn.SetReadBuffer(lc.ReadBufferSize)
}
if lc.WriteBufferSize > 0 {
_ = conn.SetWriteBuffer(lc.WriteBufferSize)
}
loggerFactory := lc.LoggerFactory
if loggerFactory == nil {
loggerFactory = logging.NewDefaultLoggerFactory()
}
logger := loggerFactory.NewLogger("transport")
2022-03-10 10:44:48 +01:00
l := &listener{
pConn: conn,
acceptCh: make(chan *Conn, lc.Backlog),
conns: make(map[string]*Conn),
doneCh: make(chan struct{}),
acceptFilter: lc.AcceptFilter,
2024-06-05 16:10:03 -04:00
connWG: &sync.WaitGroup{},
readDoneCh: make(chan struct{}),
logger: logger,
}
if lc.Batch.Enable {
l.pConn = NewBatchConn(conn, lc.Batch.WriteBatchSize, lc.Batch.WriteBatchInterval)
l.readBatchSize = lc.Batch.ReadBatchSize
2022-03-10 10:44:48 +01:00
}
l.accepting.Store(true)
l.connWG.Add(1)
l.readWG.Add(2) // wait readLoop and Close execution routine
go l.readLoop()
go func() {
l.connWG.Wait()
if err := l.pConn.Close(); err != nil {
l.errClose.Store(err)
}
l.readWG.Done()
}()
return l, nil
}
// Listen creates a new listener using a zero-value ListenConfig: default
// backlog, no accept filter, no batching and the default logger factory.
func Listen(network string, laddr *net.UDPAddr) (net.Listener, error) {
	var cfg ListenConfig
	return cfg.Listen(network, laddr)
}
// readLoop has to tasks:
2024-05-15 19:15:00 -04:00
// 1. Dispatching incoming packets to the correct Conn.
// It can therefore not be ended until all Conns are closed.
// 2. Creating a new Conn when receiving from a new remote.
2022-03-10 10:44:48 +01:00
func (l *listener) readLoop() {
defer l.readWG.Done()
2024-05-15 19:15:00 -04:00
defer close(l.readDoneCh)
2024-06-05 16:10:03 -04:00
if br, ok := l.pConn.(BatchReader); ok && l.readBatchSize > 1 {
l.readBatch(br)
} else {
l.read()
2024-05-15 19:15:00 -04:00
}
2024-06-05 16:10:03 -04:00
}
2022-03-10 10:44:48 +01:00
2024-06-05 16:10:03 -04:00
func (l *listener) readBatch(br BatchReader) {
msgs := make([]ipv4.Message, l.readBatchSize)
for i := range msgs {
msg := &msgs[i]
msg.Buffers = [][]byte{make([]byte, receiveMTU)}
msg.OOB = make([]byte, 40)
}
2022-03-10 10:44:48 +01:00
for {
2024-06-05 16:10:03 -04:00
n, err := br.ReadBatch(msgs, 0)
2022-03-10 10:44:48 +01:00
if err != nil {
2024-05-15 19:15:00 -04:00
l.errRead.Store(err)
2022-03-10 10:44:48 +01:00
return
}
2024-06-05 16:10:03 -04:00
for i := 0; i < n; i++ {
l.dispatchMsg(msgs[i].Addr, msgs[i].Buffers[0][:msgs[i].N])
}
}
}
func (l *listener) read() {
buf := make([]byte, receiveMTU)
for {
n, raddr, err := l.pConn.ReadFrom(buf)
2022-03-10 10:44:48 +01:00
if err != nil {
2024-06-05 16:10:03 -04:00
l.errRead.Store(err)
l.logger.Tracef("error reading from connection err=%v", err)
return
2022-03-10 10:44:48 +01:00
}
2024-06-05 16:10:03 -04:00
l.dispatchMsg(raddr, buf[:n])
}
}
func (l *listener) dispatchMsg(addr net.Addr, buf []byte) {
conn, ok, err := l.getConn(addr, buf)
if err != nil {
return
}
if ok {
_, err := conn.buffer.Write(buf)
if err != nil {
l.logger.Tracef("error dispatching message addr=%v err=%v", addr, err)
2022-03-10 10:44:48 +01:00
}
}
}
func (l *listener) getConn(raddr net.Addr, buf []byte) (*Conn, bool, error) {
l.connLock.Lock()
defer l.connLock.Unlock()
conn, ok := l.conns[raddr.String()]
if !ok {
2024-05-15 19:15:00 -04:00
if isAccepting, ok := l.accepting.Load().(bool); !isAccepting || !ok {
2022-03-10 10:44:48 +01:00
return nil, false, ErrClosedListener
}
if l.acceptFilter != nil {
if !l.acceptFilter(buf) {
return nil, false, nil
}
}
conn = l.newConn(raddr)
select {
case l.acceptCh <- conn:
l.conns[raddr.String()] = conn
default:
return nil, false, ErrListenQueueExceeded
}
}
return conn, true, nil
}
// Conn augments a connection-oriented connection over a UDP PacketConn.
// Each Conn carries the traffic exchanged with one remote address through
// its listener's shared socket.
type Conn struct {
	// listener owns the shared socket and the connection table.
	listener *listener
	// rAddr is the remote peer; rAddr.String() is its key in listener.conns.
	rAddr net.Addr
	// buffer holds datagrams dispatched by the read loop until Read.
	buffer *packetio.Buffer
	// doneCh is closed when the Conn is closed, or discarded unaccepted.
	doneCh   chan struct{}
	doneOnce sync.Once
	// writeDeadline gates Write; the shared socket's own write deadline
	// is never changed.
	writeDeadline *deadline.Deadline
}
// newConn allocates a Conn for rAddr with a fresh receive buffer and write
// deadline. It does not register the Conn in l.conns or queue it for Accept.
func (l *listener) newConn(rAddr net.Addr) *Conn {
	c := new(Conn)
	c.listener = l
	c.rAddr = rAddr
	c.buffer = packetio.NewBuffer()
	c.doneCh = make(chan struct{})
	c.writeDeadline = deadline.New()
	return c
}
// Read copies the next datagram buffered for c into p. Blocking and
// deadline behavior are those of the underlying packetio.Buffer.
func (c *Conn) Read(p []byte) (int, error) {
	n, err := c.buffer.Read(p)
	return n, err
}
// Write sends p to the remote address of c through the listener's shared
// UDP socket. Once the write deadline has passed, it returns
// context.DeadlineExceeded without writing.
func (c *Conn) Write(p []byte) (n int, err error) {
	select {
	case <-c.writeDeadline.Done():
		n, err = 0, context.DeadlineExceeded
	default:
		n, err = c.listener.pConn.WriteTo(p, c.rAddr)
	}
	return n, err
}
// Close closes the conn and releases any Read calls
func (c *Conn) Close() error {
var err error
c.doneOnce.Do(func() {
c.listener.connWG.Done()
close(c.doneCh)
c.listener.connLock.Lock()
delete(c.listener.conns, c.rAddr.String())
nConns := len(c.listener.conns)
c.listener.connLock.Unlock()
2024-05-15 19:15:00 -04:00
if isAccepting, ok := c.listener.accepting.Load().(bool); nConns == 0 && !isAccepting && ok {
2022-03-10 10:44:48 +01:00
// Wait if this is the final connection
c.listener.readWG.Wait()
if errClose, ok := c.listener.errClose.Load().(error); ok {
err = errClose
}
} else {
err = nil
}
2024-05-15 19:15:00 -04:00
if errBuf := c.buffer.Close(); errBuf != nil && err == nil {
err = errBuf
}
2022-03-10 10:44:48 +01:00
})
return err
}
// LocalAddr implements net.Conn.LocalAddr. Every Conn of a listener shares
// its socket, so this is the socket's local address.
func (c *Conn) LocalAddr() net.Addr {
	shared := c.listener.pConn
	return shared.LocalAddr()
}
// RemoteAddr implements net.Conn.RemoteAddr, returning the peer address
// this Conn was created for.
func (c *Conn) RemoteAddr() net.Addr {
	peer := c.rAddr
	return peer
}
// SetDeadline implements net.Conn.SetDeadline. It sets both the write
// deadline checked by Write and the read deadline on the receive buffer.
// The shared socket's deadlines are left untouched.
func (c *Conn) SetDeadline(t time.Time) error {
	c.writeDeadline.Set(t)
	return c.buffer.SetReadDeadline(t)
}
// SetReadDeadline implements net.Conn.SetReadDeadline by setting the
// deadline on this Conn's receive buffer only.
func (c *Conn) SetReadDeadline(t time.Time) error {
	if err := c.buffer.SetReadDeadline(t); err != nil {
		return err
	}
	return nil
}
// SetWriteDeadline implements net.Conn.SetWriteDeadline. Only this Conn's
// own deadline (checked by Write) is updated: the underlying connection is
// shared by every Conn of the listener, so its write deadline must not be
// changed. It always returns nil.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	c.writeDeadline.Set(t)
	return nil
}