295 lines
7.1 KiB
Go
Raw Normal View History

2022-03-10 10:44:48 +01:00
package ice
import (
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"strings"
	"sync"

	"github.com/pion/logging"
	"github.com/pion/stun"
)
// TCPMux allows grouping multiple TCP net.Conns and using them like UDP
// net.PacketConns. The main implementation of this is TCPMuxDefault, and this
// interface exists to:
//  1. prevent SEGV panics when TCPMuxDefault is not initialized by using the
//     invalidTCPMux implementation, and
//  2. allow mocking in tests.
type TCPMux interface {
	io.Closer
	// GetConnByUfrag returns the net.PacketConn associated with ufrag, or an
	// error if the mux cannot provide one.
	GetConnByUfrag(ufrag string) (net.PacketConn, error)
	// RemoveConnByUfrag drops the net.PacketConn associated with ufrag, if any.
	RemoveConnByUfrag(ufrag string)
}
// invalidTCPMux is a placeholder TCPMux whose fallible methods always report
// ErrTCPMuxNotInitialized, so callers get an error instead of a nil-pointer
// panic when no real mux was configured.
type invalidTCPMux struct{}

// newInvalidTCPMux returns a fresh placeholder mux.
func newInvalidTCPMux() *invalidTCPMux {
	return new(invalidTCPMux)
}

// Close implements TCPMux interface. It always fails with
// ErrTCPMuxNotInitialized.
func (*invalidTCPMux) Close() error {
	return ErrTCPMuxNotInitialized
}

// GetConnByUfrag implements TCPMux interface. It never yields a connection
// and always fails with ErrTCPMuxNotInitialized.
func (*invalidTCPMux) GetConnByUfrag(string) (net.PacketConn, error) {
	return nil, ErrTCPMuxNotInitialized
}

// RemoveConnByUfrag implements TCPMux interface. It does nothing.
func (*invalidTCPMux) RemoveConnByUfrag(string) {}
// TCPMuxDefault muxes TCP net.Conns into net.PacketConns and groups them by
// Ufrag. It is a default implementation of TCPMux interface.
type TCPMuxDefault struct {
	// params is the configuration the mux was created with.
	params *TCPMuxParams

	// mu guards closed and conns below.
	mu sync.Mutex
	// closed is set once Close has been called.
	closed bool
	// conns maps each ufrag to its tcpPacketConn.
	conns map[string]*tcpPacketConn

	// wg counts the goroutines started by the mux so Close can wait for them.
	wg sync.WaitGroup
}
// TCPMuxParams are parameters for TCPMux.
type TCPMuxParams struct {
	// Listener supplies incoming TCP connections. NewTCPMuxDefault accepts
	// from it in a background goroutine, and TCPMuxDefault.Close closes it.
	Listener net.Listener
	// Logger receives the mux's log output. If nil, NewTCPMuxDefault
	// substitutes a default logger named "ice".
	Logger logging.LeveledLogger
	// ReadBufferSize is passed as the ReadBuffer parameter of every
	// tcpPacketConn the mux creates.
	ReadBufferSize int
}
// NewTCPMuxDefault creates a TCPMuxDefault from params and immediately starts
// accepting connections from params.Listener in a background goroutine.
// When params.Logger is nil, a default logger named "ice" is used instead.
// Call Close to stop the accept loop and release the mux's resources.
func NewTCPMuxDefault(params TCPMuxParams) *TCPMuxDefault {
	if params.Logger == nil {
		loggerFactory := logging.NewDefaultLoggerFactory()
		params.Logger = loggerFactory.NewLogger("ice")
	}

	mux := &TCPMuxDefault{
		conns:  make(map[string]*tcpPacketConn),
		params: &params,
	}

	// The accept loop is tracked by wg so Close can wait for it to exit.
	mux.wg.Add(1)
	go func() {
		defer mux.wg.Done()
		mux.start()
	}()

	return mux
}
// start runs the accept loop. Every accepted connection is handed to
// handleConn on its own goroutine, tracked by m.wg. The loop returns on the
// first Accept error; Close closes the listener, which is expected to end
// the loop that way.
func (m *TCPMuxDefault) start() {
	listener, logger := m.params.Listener, m.params.Logger
	logger.Infof("Listening TCP on %s\n", listener.Addr())

	for {
		conn, err := listener.Accept()
		if err != nil {
			logger.Infof("Error accepting connection: %s\n", err)
			return
		}
		logger.Debugf("Accepted connection from: %s to %s", conn.RemoteAddr(), conn.LocalAddr())

		m.wg.Add(1)
		go func(c net.Conn) {
			defer m.wg.Done()
			m.handleConn(c)
		}(conn)
	}
}
// LocalAddr reports the address the mux's listener is bound to.
func (m *TCPMuxDefault) LocalAddr() net.Addr {
	listener := m.params.Listener
	return listener.Addr()
}
// GetConnByUfrag returns the net.PacketConn registered for ufrag. If none is
// registered yet, it creates and registers one bound to the listener's
// address. Once the mux has been closed it fails with io.ErrClosedPipe.
func (m *TCPMuxDefault) GetConnByUfrag(ufrag string) (net.PacketConn, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.closed {
		return nil, io.ErrClosedPipe
	}
	if existing, found := m.conns[ufrag]; found {
		return existing, nil
	}
	return m.createConn(ufrag, m.LocalAddr()), nil
}
// createConn creates a tcpPacketConn for ufrag bound to localAddr and
// registers it in m.conns. It also starts a goroutine, tracked by m.wg, that
// waits for the conn's CloseChannel and then unregisters it.
//
// The caller must hold m.mu, because this writes m.conns.
func (m *TCPMuxDefault) createConn(ufrag string, localAddr net.Addr) *tcpPacketConn {
	conn := newTCPPacketConn(tcpPacketParams{
		ReadBuffer: m.params.ReadBufferSize,
		LocalAddr:  localAddr,
		Logger:     m.params.Logger,
	})
	m.conns[ufrag] = conn

	m.wg.Add(1)
	go func() {
		defer m.wg.Done()
		<-conn.CloseChannel()

		m.mu.Lock()
		defer m.mu.Unlock()
		// Only unregister if the map still refers to THIS conn. Between the
		// close signal and acquiring mu, the ufrag may have been removed
		// (RemoveConnByUfrag / Close) and re-registered with a brand-new
		// tcpPacketConn. Removing by key alone would tear that new conn down.
		if current, ok := m.conns[ufrag]; ok && current == conn {
			m.closeAndLogError(conn)
			delete(m.conns, ufrag)
		}
	}()

	return conn
}
// closeAndLogError closes closer. Instead of returning a close failure, it
// reports the failure at warning level.
func (m *TCPMuxDefault) closeAndLogError(closer io.Closer) {
	if err := closer.Close(); err != nil {
		m.params.Logger.Warnf("Error closing connection: %s", err)
	}
}
// handleConn takes ownership of a freshly accepted TCP connection. It works
// in three steps:
//  1. Read the first RFC 4571 framed packet.
//  2. Require that packet to be a STUN Binding message carrying a USERNAME
//     attribute.
//  3. Pass conn, together with that first packet, to the tcpPacketConn
//     registered under the ufrag.
//
// The ufrag is the part of USERNAME before the first ':'. If no tcpPacketConn
// is registered under it, one is created.
//
// conn is closed on every failure path and when the mux has already been
// closed, so it is never leaked.
func (m *TCPMuxDefault) handleConn(conn net.Conn) {
	buf := make([]byte, receiveMTU)

	n, err := readStreamingPacket(conn, buf)
	if err != nil {
		m.closeAndLogError(conn)
		m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr().String(), err)
		return
	}
	buf = buf[:n]

	msg := &stun.Message{
		Raw: make([]byte, len(buf)),
	}
	// Explicitly copy raw buffer so Message can own the memory.
	copy(msg.Raw, buf)
	if err = msg.Decode(); err != nil {
		m.closeAndLogError(conn)
		m.params.Logger.Warnf("Failed to handle decode ICE from %s to %s: %v\n", conn.RemoteAddr(), conn.LocalAddr(), err)
		return
	}

	// Only a STUN Binding message may open a muxed connection.
	if msg.Type.Method != stun.MethodBinding {
		m.closeAndLogError(conn)
		m.params.Logger.Warnf("Not a STUN message from %s to %s\n", conn.RemoteAddr(), conn.LocalAddr())
		return
	}

	for _, attr := range msg.Attributes {
		m.params.Logger.Debugf("msg attr: %s\n", attr.String())
	}

	attr, err := msg.Get(stun.AttrUsername)
	if err != nil {
		m.closeAndLogError(conn)
		m.params.Logger.Warnf("No Username attribute in STUN message from %s to %s\n", conn.RemoteAddr(), conn.LocalAddr())
		return
	}

	ufrag := strings.Split(string(attr), ":")[0]
	m.params.Logger.Debugf("Ufrag: %s\n", ufrag)

	m.mu.Lock()
	defer m.mu.Unlock()

	// Close may have run while the first packet was being read. At that
	// point it has already emptied m.conns and waits in m.wg.Wait.
	// Registering a new tcpPacketConn now would leak it, and the watcher
	// goroutine started by createConn could then keep Close from returning.
	if m.closed {
		m.closeAndLogError(conn)
		m.params.Logger.Debugf("Dropping connection from %s to %s: TCPMux is closed", conn.RemoteAddr(), conn.LocalAddr())
		return
	}

	packetConn, ok := m.conns[ufrag]
	if !ok {
		packetConn = m.createConn(ufrag, conn.LocalAddr())
	}
	if err := packetConn.AddConn(conn, buf); err != nil {
		m.closeAndLogError(conn)
		m.params.Logger.Warnf("Error adding conn to tcpPacketConn from %s to %s: %s\n", conn.RemoteAddr(), conn.LocalAddr(), err)
		return
	}
}
// Close shuts the mux down in three steps:
//  1. Mark the mux closed.
//  2. Close and forget every registered tcpPacketConn.
//  3. Close the listener, which ends the accept loop.
//
// The lock is released before Close blocks until every goroutine tracked by
// m.wg has returned. The result is the error from closing the listener.
func (m *TCPMuxDefault) Close() error {
	m.mu.Lock()
	m.closed = true
	for ufrag := range m.conns {
		m.closeAndLogError(m.conns[ufrag])
	}
	m.conns = make(map[string]*tcpPacketConn)
	listenerErr := m.params.Listener.Close()
	m.mu.Unlock()

	// Wait outside the lock: the tracked goroutines may need m.mu to finish.
	m.wg.Wait()
	return listenerErr
}
// RemoveConnByUfrag closes the net.PacketConn registered for ufrag and
// unregisters it. It does nothing when ufrag is unknown.
func (m *TCPMuxDefault) RemoveConnByUfrag(ufrag string) {
	m.mu.Lock()
	defer m.mu.Unlock()

	conn, found := m.conns[ufrag]
	if !found {
		return
	}
	m.closeAndLogError(conn)
	delete(m.conns, ufrag)
}
// streamingPacketHeaderLen is the size of the RFC 4571 length prefix.
const streamingPacketHeaderLen = 2

// readStreamingPacket reads one RFC 4571 framed packet from conn into buf and
// returns the payload length.
//
// https://tools.ietf.org/html/rfc4571#section-2
// A 2-byte big-endian length header prepends each packet:
//
//	 0                   1                   2                   3
//	 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
//	-----------------------------------------------------------------
//	|             LENGTH            |  RTP or RTCP packet ...       |
//	-----------------------------------------------------------------
//
// Return behavior:
//   - If the announced length exceeds cap(buf), it returns that length with
//     io.ErrShortBuffer, and the payload is left unread.
//   - On any other read error it returns 0 and the error.
//   - A clean end of stream before a header returns io.EOF.
//   - A stream that ends partway through a header or payload returns
//     io.ErrUnexpectedEOF.
//   - A Read that delivers the final bytes together with an error (allowed by
//     io.Reader) still counts as a complete packet.
func readStreamingPacket(conn net.Conn, buf []byte) (int, error) {
	var header [streamingPacketHeaderLen]byte
	if _, err := io.ReadFull(conn, header[:]); err != nil {
		return 0, err
	}

	length := int(binary.BigEndian.Uint16(header[:]))
	if length > cap(buf) {
		return length, io.ErrShortBuffer
	}

	// buf[:length] may extend past len(buf) up to cap(buf), matching the
	// cap-based bound check above.
	if _, err := io.ReadFull(conn, buf[:length]); err != nil {
		return 0, err
	}
	return length, nil
}

// writeStreamingPacket writes buf to conn as one RFC 4571 framed packet: a
// 2-byte big-endian length prefix followed by the payload. It uses a single
// Write call and returns the number of payload bytes written.
//
// The length prefix can only express payloads up to 65535 bytes. Larger
// payloads are rejected with an error and nothing is written. Writing them
// anyway would emit a truncated length and desynchronize the stream.
func writeStreamingPacket(conn net.Conn, buf []byte) (int, error) {
	const maxPayloadLen = 0xFFFF // largest value of the uint16 length prefix
	if len(buf) > maxPayloadLen {
		return 0, fmt.Errorf("streaming packet of %d bytes exceeds the maximum of %d", len(buf), maxPayloadLen)
	}

	frame := make([]byte, streamingPacketHeaderLen+len(buf))
	binary.BigEndian.PutUint16(frame, uint16(len(buf)))
	copy(frame[streamingPacketHeaderLen:], buf)

	n, err := conn.Write(frame)
	if err != nil {
		return 0, err
	}
	return n - streamingPacketHeaderLen, nil
}