status-go/vendor/github.com/pion/ice/v2/tcp_packet_conn.go

305 lines
5.7 KiB
Go

// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package ice
import (
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/pion/logging"
"github.com/pion/transport/v2/packetio"
)
// bufferedConn is a net.Conn whose Write queues data in a packetio.Buffer
// instead of writing to the underlying connection directly. A background
// goroutine (writeProcess) drains the buffer into the embedded Conn.
type bufferedConn struct {
	// Conn is the underlying connection; every method other than Write and
	// Close is promoted from it unchanged.
	net.Conn

	// buf holds the packets queued by Write until writeProcess sends them.
	buf *packetio.Buffer

	// logger receives warnings about buffer read and connection write errors.
	logger logging.LeveledLogger

	// closed is set to 1 by Close and is accessed with sync/atomic.
	closed int32
}
// newBufferedConn wraps conn in a bufferedConn and starts the goroutine that
// flushes queued writes to conn. When bufSize is positive it is applied as the
// buffer's size limit; a non-positive bufSize leaves the limit untouched.
func newBufferedConn(conn net.Conn, bufSize int, logger logging.LeveledLogger) net.Conn {
	wrapped := &bufferedConn{
		Conn:   conn,
		buf:    packetio.NewBuffer(),
		logger: logger,
	}
	if bufSize > 0 {
		wrapped.buf.SetLimitSize(bufSize)
	}

	// The flusher goroutine ends when the buffer reports io.EOF or the closed
	// flag is observed (see writeProcess).
	go wrapped.writeProcess()

	return wrapped
}
// Write queues b in the internal buffer rather than writing to the underlying
// connection. It returns whatever the buffer's Write returns; the actual
// network write happens later in writeProcess.
func (bc *bufferedConn) Write(b []byte) (int, error) {
	return bc.buf.Write(b)
}
// writeProcess drains the internal buffer into the underlying connection until
// the buffer returns io.EOF or the closed flag is set. Other read errors and
// all write errors are logged as warnings and the loop keeps going.
func (bc *bufferedConn) writeProcess() {
	scratch := make([]byte, receiveMTU)

	for atomic.LoadInt32(&bc.closed) == 0 {
		size, readErr := bc.buf.Read(scratch)
		switch {
		case errors.Is(readErr, io.EOF):
			return
		case readErr != nil:
			bc.logger.Warnf("read buffer error: %s", readErr)
			continue
		}

		if _, writeErr := bc.Conn.Write(scratch[:size]); writeErr != nil {
			bc.logger.Warnf("write error: %s", writeErr)
		}
	}
}
// Close marks the connection as closed, closes the internal buffer, and then
// closes the underlying connection, returning that connection's close error.
func (bc *bufferedConn) Close() error {
	// Raise the flag first so writeProcess stops at its next loop check.
	atomic.StoreInt32(&bc.closed, 1)

	// The buffer's close error is intentionally discarded; only the
	// underlying connection's result is reported to the caller.
	_ = bc.buf.Close()

	connErr := bc.Conn.Close()
	return connErr
}
// tcpPacketConn multiplexes several stream connections behind a packet-style
// interface (ReadFrom / WriteTo), keyed by each connection's remote address.
type tcpPacketConn struct {
	// params holds the configuration supplied to newTCPPacketConn.
	params *tcpPacketParams

	// conns is a map of net.Conns indexed by remote net.Addr.String()
	conns map[string]net.Conn

	// recvChan carries packets (or read errors) from the per-connection
	// reader goroutines to ReadFrom.
	recvChan chan streamingPacket

	// mu guards conns.
	mu sync.Mutex

	// wg tracks the reader goroutines started by AddConn so Close can wait
	// for them before closing recvChan.
	wg sync.WaitGroup

	// closedChan is closed exactly once (via closeOnce) when Close is called.
	closedChan chan struct{}
	closeOnce  sync.Once
}
// streamingPacket is one item handed from a reader goroutine to ReadFrom:
// either a payload or the error that ended reading on a connection.
//
// The field order is relied on by positional literals in this file.
type streamingPacket struct {
	// Data is the packet payload; nil when Err is set.
	Data []byte

	// RAddr is the remote address of the connection the packet came from.
	RAddr net.Addr

	// Err is the read error, if any.
	Err error
}
// tcpPacketParams is the configuration for a tcpPacketConn.
type tcpPacketParams struct {
	// ReadBuffer is the capacity of the receive channel, in packets.
	ReadBuffer int

	// LocalAddr is the address reported by tcpPacketConn.LocalAddr.
	LocalAddr net.Addr

	// Logger receives the connection's log output.
	Logger logging.LeveledLogger

	// WriteBuffer, when positive, makes AddConn wrap each connection in a
	// buffered writer with this size limit.
	WriteBuffer int
}
// newTCPPacketConn builds an empty tcpPacketConn from params. The receive
// channel is buffered to params.ReadBuffer packets; connections are attached
// later with AddConn.
func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn {
	return &tcpPacketConn{
		params:     &params,
		conns:      make(map[string]net.Conn),
		recvChan:   make(chan streamingPacket, params.ReadBuffer),
		closedChan: make(chan struct{}),
	}
}
// AddConn registers conn under its remote address and starts a goroutine that
// reads framed packets from it. If firstPacketData is non-nil it is delivered
// to the receive channel before reading starts.
//
// It returns io.ErrClosedPipe if the packet conn is already closed, and an
// error wrapping errConnectionAddrAlreadyExist if a connection with the same
// remote address is already registered.
func (t *tcpPacketConn) AddConn(conn net.Conn, firstPacketData []byte) error {
	t.params.Logger.Infof("AddConn: %s remote %s to local %s", conn.RemoteAddr().Network(), conn.RemoteAddr(), conn.LocalAddr())

	t.mu.Lock()
	defer t.mu.Unlock()

	if t.isClosed() {
		return io.ErrClosedPipe
	}

	key := conn.RemoteAddr().String()
	if _, exists := t.conns[key]; exists {
		return fmt.Errorf("%w: %s", errConnectionAddrAlreadyExist, key)
	}

	// Optionally route writes through a buffer so WriteTo does not block on
	// the socket.
	if t.params.WriteBuffer > 0 {
		conn = newBufferedConn(conn, t.params.WriteBuffer, t.params.Logger)
	}
	t.conns[key] = conn

	t.wg.Add(1)
	go func() {
		defer t.wg.Done()

		if firstPacketData != nil {
			first := streamingPacket{firstPacketData, conn.RemoteAddr(), nil}
			select {
			case t.recvChan <- first:
			case <-t.closedChan:
				// NOTE: recvChan can fill up and never drain in edge
				// cases while closing a connection, which can cause the
				// packetConn to never finish closing. Bail out early
				// here to prevent that.
				return
			}
		}

		t.startReading(conn)
	}()

	return nil
}
// startReading reads framed packets from conn in a loop and forwards a copy of
// each to handleRecv. On the first read error it logs the error, forwards it
// as an error packet, removes (and closes) conn, and returns.
func (t *tcpPacketConn) startReading(conn net.Conn) {
	scratch := make([]byte, receiveMTU)

	for {
		size, err := readStreamingPacket(conn, scratch)
		if err == nil {
			// Copy out of the reusable scratch buffer so the next read
			// cannot overwrite a packet that is still queued.
			payload := make([]byte, size)
			copy(payload, scratch[:size])
			t.handleRecv(streamingPacket{payload, conn.RemoteAddr(), nil})
			continue
		}

		t.params.Logger.Infof("%v: %s", errReadingStreamingPacket, err)
		t.handleRecv(streamingPacket{nil, conn.RemoteAddr(), err})
		t.removeConn(conn)
		return
	}
}
// handleRecv delivers pkt to the receive channel, or gives up as soon as the
// packet conn is closed. If the conn is already closed when checked, the send
// target is left nil so that only the closed case of the select can fire.
func (t *tcpPacketConn) handleRecv(pkt streamingPacket) {
	var out chan streamingPacket

	t.mu.Lock()
	if !t.isClosed() {
		out = t.recvChan
	}
	t.mu.Unlock()

	select {
	case <-t.closedChan:
	case out <- pkt:
	}
}
// isClosed reports, without blocking, whether closedChan has been closed.
func (t *tcpPacketConn) isClosed() bool {
	select {
	case <-t.closedChan:
		return true
	default:
	}
	return false
}
// ReadFrom is for passive and s-o candidates.
//
// It blocks until a packet is available on the receive channel and copies its
// payload into b, returning the payload length and the remote address it came
// from.
//
// Errors:
//   - io.ErrClosedPipe when the receive channel has been closed.
//   - the read error carried by an error packet, with that packet's address.
//   - io.ErrShortBuffer when b is too small for the payload; the packet is
//     consumed and dropped in that case.
func (t *tcpPacketConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
	pkt, ok := <-t.recvChan
	if !ok {
		return 0, nil, io.ErrClosedPipe
	}

	if pkt.Err != nil {
		return 0, pkt.RAddr, pkt.Err
	}

	// copy writes at most len(b) bytes, so the size check must use len(b),
	// not cap(b); otherwise a slice with spare capacity would receive a
	// truncated payload while the full length is reported.
	if len(b) < len(pkt.Data) {
		return 0, pkt.RAddr, io.ErrShortBuffer
	}

	n = copy(b, pkt.Data)
	return n, pkt.RAddr, nil
}
// WriteTo is for active and s-o candidates.
//
// It looks up the connection registered for rAddr and writes buf to it as one
// framed packet. It returns io.ErrClosedPipe when no connection is registered
// for rAddr; otherwise it returns the result of the framed write. Write
// failures are also logged at trace level.
func (t *tcpPacketConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
	t.mu.Lock()
	conn, ok := t.conns[rAddr.String()]
	t.mu.Unlock()

	if !ok {
		return 0, io.ErrClosedPipe
	}

	n, err = writeStreamingPacket(conn, buf)
	if err != nil {
		// %w is only understood by fmt.Errorf; use %v here so the log line
		// is formatted properly, and include the underlying error.
		t.params.Logger.Tracef("%v %s: %v", errWriting, rAddr, err)
	}

	return n, err
}
// closeAndLogError closes closer and logs a warning if the close fails. The
// error is not returned to the caller.
func (t *tcpPacketConn) closeAndLogError(closer io.Closer) {
	if err := closer.Close(); err != nil {
		t.params.Logger.Warnf("%v: %s", errClosingConnection, err)
	}
}
// removeConn closes conn (logging any close error) and removes the entry for
// its remote address from the connection map, all under the mutex.
func (t *tcpPacketConn) removeConn(conn net.Conn) {
	t.mu.Lock()
	defer t.mu.Unlock()

	t.closeAndLogError(conn)

	key := conn.RemoteAddr().String()
	delete(t.conns, key)
}
// Close shuts the packet conn down: it closes closedChan (first call only),
// closes and forgets every registered connection, waits for the reader
// goroutines to finish, and finally closes the receive channel (first call
// only). It always returns nil and is safe to call more than once.
func (t *tcpPacketConn) Close() error {
	firstClose := false

	t.mu.Lock()
	t.closeOnce.Do(func() {
		close(t.closedChan)
		firstClose = true
	})
	for _, c := range t.conns {
		t.closeAndLogError(c)
		delete(t.conns, c.RemoteAddr().String())
	}
	t.mu.Unlock()

	// Readers must be gone before recvChan is closed, otherwise one of them
	// could send on a closed channel.
	t.wg.Wait()

	if firstClose {
		close(t.recvChan)
	}

	return nil
}
// LocalAddr returns the local address configured in the params.
func (t *tcpPacketConn) LocalAddr() net.Addr {
	local := t.params.LocalAddr
	return local
}
// SetDeadline is a no-op: the argument is ignored and nil is always returned.
func (t *tcpPacketConn) SetDeadline(_ time.Time) error {
	return nil
}
// SetReadDeadline is a no-op: the argument is ignored and nil is always
// returned.
func (t *tcpPacketConn) SetReadDeadline(_ time.Time) error {
	return nil
}
// SetWriteDeadline is a no-op: the argument is ignored and nil is always
// returned.
func (t *tcpPacketConn) SetWriteDeadline(_ time.Time) error {
	return nil
}
// CloseChannel returns a receive-only view of closedChan, which is closed
// when Close is called.
func (t *tcpPacketConn) CloseChannel() <-chan struct{} {
	var done <-chan struct{} = t.closedChan
	return done
}
// String returns a short description of the packet conn that includes its
// configured local address.
func (t *tcpPacketConn) String() string {
	local := t.params.LocalAddr
	return fmt.Sprintf("tcpPacketConn{LocalAddr: %s}", local)
}