241 lines
4.5 KiB
Go
241 lines
4.5 KiB
Go
|
package ice
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"net"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/pion/logging"
|
||
|
)
|
||
|
|
||
|
type tcpPacketConn struct {
|
||
|
params *tcpPacketParams
|
||
|
|
||
|
// conns is a map of net.Conns indexed by remote net.Addr.String()
|
||
|
conns map[string]net.Conn
|
||
|
|
||
|
recvChan chan streamingPacket
|
||
|
|
||
|
mu sync.Mutex
|
||
|
wg sync.WaitGroup
|
||
|
closedChan chan struct{}
|
||
|
closeOnce sync.Once
|
||
|
}
|
||
|
|
||
|
// streamingPacket is one item queued on the receive channel: either a payload
// read from a connection or the error that ended reading from it.
//
// The field order is significant because the type is constructed with
// positional (unkeyed) composite literals.
type streamingPacket struct {
	// Data is the packet payload; it is nil when Err is set.
	Data []byte
	// RAddr is the remote address of the connection the packet came from.
	RAddr net.Addr
	// Err is the read error, or nil for a successfully read packet.
	Err error
}
|
||
|
|
||
|
type tcpPacketParams struct {
|
||
|
ReadBuffer int
|
||
|
LocalAddr net.Addr
|
||
|
Logger logging.LeveledLogger
|
||
|
}
|
||
|
|
||
|
func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn {
|
||
|
p := &tcpPacketConn{
|
||
|
params: ¶ms,
|
||
|
|
||
|
conns: map[string]net.Conn{},
|
||
|
|
||
|
recvChan: make(chan streamingPacket, params.ReadBuffer),
|
||
|
closedChan: make(chan struct{}),
|
||
|
}
|
||
|
|
||
|
return p
|
||
|
}
|
||
|
|
||
|
func (t *tcpPacketConn) AddConn(conn net.Conn, firstPacketData []byte) error {
|
||
|
t.params.Logger.Infof("AddConn: %s %s", conn.RemoteAddr().Network(), conn.RemoteAddr())
|
||
|
|
||
|
t.mu.Lock()
|
||
|
defer t.mu.Unlock()
|
||
|
|
||
|
select {
|
||
|
case <-t.closedChan:
|
||
|
return io.ErrClosedPipe
|
||
|
default:
|
||
|
}
|
||
|
|
||
|
if _, ok := t.conns[conn.RemoteAddr().String()]; ok {
|
||
|
return fmt.Errorf("%w: %s", errConnectionAddrAlreadyExist, conn.RemoteAddr().String())
|
||
|
}
|
||
|
|
||
|
t.conns[conn.RemoteAddr().String()] = conn
|
||
|
|
||
|
t.wg.Add(1)
|
||
|
go func() {
|
||
|
if firstPacketData != nil {
|
||
|
t.recvChan <- streamingPacket{firstPacketData, conn.RemoteAddr(), nil}
|
||
|
}
|
||
|
defer t.wg.Done()
|
||
|
t.startReading(conn)
|
||
|
}()
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (t *tcpPacketConn) startReading(conn net.Conn) {
|
||
|
buf := make([]byte, receiveMTU)
|
||
|
|
||
|
for {
|
||
|
n, err := readStreamingPacket(conn, buf)
|
||
|
// t.params.Logger.Infof("readStreamingPacket read %d bytes", n)
|
||
|
if err != nil {
|
||
|
t.params.Logger.Infof("%w: %s\n", errReadingStreamingPacket, err)
|
||
|
t.handleRecv(streamingPacket{nil, conn.RemoteAddr(), err})
|
||
|
t.removeConn(conn)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
data := make([]byte, n)
|
||
|
copy(data, buf[:n])
|
||
|
|
||
|
// t.params.Logger.Infof("Writing read streaming packet to recvChan: %d bytes", len(data))
|
||
|
t.handleRecv(streamingPacket{data, conn.RemoteAddr(), nil})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (t *tcpPacketConn) handleRecv(pkt streamingPacket) {
|
||
|
t.mu.Lock()
|
||
|
|
||
|
recvChan := t.recvChan
|
||
|
if t.isClosed() {
|
||
|
recvChan = nil
|
||
|
}
|
||
|
|
||
|
t.mu.Unlock()
|
||
|
|
||
|
select {
|
||
|
case recvChan <- pkt:
|
||
|
case <-t.closedChan:
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (t *tcpPacketConn) isClosed() bool {
|
||
|
select {
|
||
|
case <-t.closedChan:
|
||
|
return true
|
||
|
default:
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// WriteTo is for passive and s-o candidates.
|
||
|
func (t *tcpPacketConn) ReadFrom(b []byte) (n int, raddr net.Addr, err error) {
|
||
|
pkt, ok := <-t.recvChan
|
||
|
|
||
|
if !ok {
|
||
|
return 0, nil, io.ErrClosedPipe
|
||
|
}
|
||
|
|
||
|
if pkt.Err != nil {
|
||
|
return 0, pkt.RAddr, pkt.Err
|
||
|
}
|
||
|
|
||
|
if cap(b) < len(pkt.Data) {
|
||
|
return 0, pkt.RAddr, io.ErrShortBuffer
|
||
|
}
|
||
|
|
||
|
n = len(pkt.Data)
|
||
|
copy(b, pkt.Data[:n])
|
||
|
return n, pkt.RAddr, err
|
||
|
}
|
||
|
|
||
|
// WriteTo is for active and s-o candidates.
|
||
|
func (t *tcpPacketConn) WriteTo(buf []byte, raddr net.Addr) (n int, err error) {
|
||
|
t.mu.Lock()
|
||
|
defer t.mu.Unlock()
|
||
|
|
||
|
conn, ok := t.conns[raddr.String()]
|
||
|
if !ok {
|
||
|
return 0, io.ErrClosedPipe
|
||
|
// conn, err := net.DialTCP(tcp, nil, raddr.(*net.TCPAddr))
|
||
|
|
||
|
// if err != nil {
|
||
|
// t.params.Logger.Tracef("DialTCP error: %s", err)
|
||
|
// return 0, err
|
||
|
// }
|
||
|
|
||
|
// go t.startReading(conn)
|
||
|
// t.conns[raddr.String()] = conn
|
||
|
}
|
||
|
|
||
|
n, err = writeStreamingPacket(conn, buf)
|
||
|
if err != nil {
|
||
|
t.params.Logger.Tracef("%w %s\n", errWriting, raddr)
|
||
|
return n, err
|
||
|
}
|
||
|
|
||
|
return n, err
|
||
|
}
|
||
|
|
||
|
func (t *tcpPacketConn) closeAndLogError(closer io.Closer) {
|
||
|
err := closer.Close()
|
||
|
if err != nil {
|
||
|
t.params.Logger.Warnf("%w: %s", errClosingConnection, err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (t *tcpPacketConn) removeConn(conn net.Conn) {
|
||
|
t.mu.Lock()
|
||
|
defer t.mu.Unlock()
|
||
|
|
||
|
t.closeAndLogError(conn)
|
||
|
|
||
|
delete(t.conns, conn.RemoteAddr().String())
|
||
|
}
|
||
|
|
||
|
func (t *tcpPacketConn) Close() error {
|
||
|
t.mu.Lock()
|
||
|
|
||
|
var shouldCloseRecvChan bool
|
||
|
t.closeOnce.Do(func() {
|
||
|
close(t.closedChan)
|
||
|
shouldCloseRecvChan = true
|
||
|
})
|
||
|
|
||
|
for _, conn := range t.conns {
|
||
|
t.closeAndLogError(conn)
|
||
|
delete(t.conns, conn.RemoteAddr().String())
|
||
|
}
|
||
|
|
||
|
t.mu.Unlock()
|
||
|
|
||
|
t.wg.Wait()
|
||
|
|
||
|
if shouldCloseRecvChan {
|
||
|
close(t.recvChan)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (t *tcpPacketConn) LocalAddr() net.Addr {
|
||
|
return t.params.LocalAddr
|
||
|
}
|
||
|
|
||
|
func (t *tcpPacketConn) SetDeadline(tm time.Time) error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (t *tcpPacketConn) SetReadDeadline(tm time.Time) error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (t *tcpPacketConn) SetWriteDeadline(tm time.Time) error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (t *tcpPacketConn) CloseChannel() <-chan struct{} {
|
||
|
return t.closedChan
|
||
|
}
|
||
|
|
||
|
func (t *tcpPacketConn) String() string {
|
||
|
return fmt.Sprintf("tcpPacketConn{LocalAddr: %s}", t.params.LocalAddr)
|
||
|
}
|