// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package stun import ( "crypto/tls" "errors" "fmt" "io" "log" "net" "runtime" "strconv" "sync" "sync/atomic" "time" "github.com/pion/dtls/v2" "github.com/pion/transport/v2" "github.com/pion/transport/v2/stdnet" ) // ErrUnsupportedURI is an error thrown if the user passes an unsupported STUN or TURN URI var ErrUnsupportedURI = fmt.Errorf("invalid schema or transport") // Dial connects to the address on the named network and then // initializes Client on that connection, returning error if any. func Dial(network, address string) (*Client, error) { conn, err := net.Dial(network, address) if err != nil { return nil, err } return NewClient(conn) } // DialConfig is used to pass configuration to DialURI() type DialConfig struct { DTLSConfig dtls.Config TLSConfig tls.Config Net transport.Net } // DialURI connect to the STUN/TURN URI and then // initializes Client on that connection, returning error if any. func DialURI(uri *URI, cfg *DialConfig) (*Client, error) { var conn Connection var err error nw := cfg.Net if nw == nil { nw, err = stdnet.NewNet() if err != nil { return nil, fmt.Errorf("failed to create net: %w", err) } } addr := net.JoinHostPort(uri.Host, strconv.Itoa(uri.Port)) switch { case uri.Scheme == SchemeTypeSTUN: if conn, err = nw.Dial("udp", addr); err != nil { return nil, fmt.Errorf("failed to listen: %w", err) } case uri.Scheme == SchemeTypeTURN: network := "udp" //nolint:goconst if uri.Proto == ProtoTypeTCP { network = "tcp" //nolint:goconst } if conn, err = nw.Dial(network, addr); err != nil { return nil, fmt.Errorf("failed to dial: %w", err) } case uri.Scheme == SchemeTypeTURNS && uri.Proto == ProtoTypeUDP: dtlsCfg := cfg.DTLSConfig // Copy dtlsCfg.ServerName = uri.Host udpConn, err := nw.Dial("udp", addr) if err != nil { return nil, fmt.Errorf("failed to dial: %w", err) } if conn, err = dtls.Client(udpConn, &dtlsCfg); err != nil { return nil, fmt.Errorf("failed to connect 
to '%s': %w", addr, err) } case (uri.Scheme == SchemeTypeTURNS || uri.Scheme == SchemeTypeSTUNS) && uri.Proto == ProtoTypeTCP: tlsCfg := cfg.TLSConfig //nolint:govet tlsCfg.ServerName = uri.Host tcpConn, err := nw.Dial("tcp", addr) if err != nil { return nil, fmt.Errorf("failed to dial: %w", err) } conn = tls.Client(tcpConn, &tlsCfg) default: return nil, ErrUnsupportedURI } return NewClient(conn) } // ErrNoConnection means that ClientOptions.Connection is nil. var ErrNoConnection = errors.New("no connection provided") // ClientOption sets some client option. type ClientOption func(c *Client) // WithHandler sets client handler which is called if Agent emits the Event // with TransactionID that is not currently registered by Client. // Useful for handling Data indications from TURN server. func WithHandler(h Handler) ClientOption { return func(c *Client) { c.handler = h } } // WithRTO sets client RTO as defined in STUN RFC. func WithRTO(rto time.Duration) ClientOption { return func(c *Client) { c.rto = int64(rto) } } // WithClock sets Clock of client, the source of current time. // Also clock is passed to default collector if set. func WithClock(clock Clock) ClientOption { return func(c *Client) { c.clock = clock } } // WithTimeoutRate sets RTO timer minimum resolution. func WithTimeoutRate(d time.Duration) ClientOption { return func(c *Client) { c.rtoRate = d } } // WithAgent sets client STUN agent. // // Defaults to agent implementation in current package, // see agent.go. func WithAgent(a ClientAgent) ClientOption { return func(c *Client) { c.a = a } } // WithCollector rests client timeout collector, the implementation // of ticker which calls function on each tick. func WithCollector(coll Collector) ClientOption { return func(c *Client) { c.collector = coll } } // WithNoConnClose prevents client from closing underlying connection when // the Close() method is called. 
func WithNoConnClose() ClientOption {
	keepConnOpen := func(c *Client) {
		c.closeConn = false
	}
	return keepConnOpen
}

// WithNoRetransmit is a ClientOption that disables retransmissions by setting
// the maximum number of retransmission attempts to zero.
//
// If the RTO is zero when this option is applied, it is replaced by
// defaultMaxAttempts * defaultRTO so that transactions still time out.
// NewClient seeds the RTO with defaultRTO before applying options. This
// fallback therefore only takes effect when an earlier option (for example
// WithRTO(0)) cleared it. Otherwise the current RTO is kept as the timeout.
//
// Useful for TCP connections where transport handles RTO.
func WithNoRetransmit(c *Client) {
	c.maxAttempts = 0
	if c.rto != 0 {
		return
	}
	c.rto = defaultMaxAttempts * int64(defaultRTO)
}

// Defaults used by NewClient for options that were not provided.
const (
	defaultRTO         = 300 * time.Millisecond // initial retransmission timeout
	defaultTimeoutRate = 5 * time.Millisecond   // collector tick interval (RTO timer resolution)
	defaultMaxAttempts = 7                      // maximum retransmissions per transaction
)

// NewClient creates a Client on top of conn, applies the given options,
// fills in defaults for anything left unset and starts the internal
// goroutines (timeout collector and reader). Call Close when done with the
// Client to stop them and release its resources.
//
// By default Close also closes conn; pass WithNoConnClose to keep it open.
// ErrNoConnection is returned when conn is nil.
//
// Client does not multiplex protocols: if the same conn carries application
// data, wrap it in your own (de-)multiplexer and pass the wrapper as conn.
func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
	client := &Client{
		c:           conn,
		close:       make(chan struct{}),
		closeConn:   true,
		clock:       systemClock(),
		maxAttempts: defaultMaxAttempts,
		rto:         int64(defaultRTO),
		rtoRate:     defaultTimeoutRate,
		t:           make(map[transactionID]*clientTransaction, 100),
	}

	// Options run before validation so that they may override any default.
	for _, apply := range options {
		apply(client)
	}

	if client.c == nil {
		return nil, ErrNoConnection
	}

	if client.a == nil {
		client.a = NewAgent(nil)
	}
	if err := client.a.SetHandler(client.handleAgentCallback); err != nil {
		return nil, err
	}

	if client.collector == nil {
		client.collector = &tickerCollector{
			close: make(chan struct{}),
			clock: client.clock,
		}
	}
	// On every tick, let the agent expire transactions whose deadline passed.
	collect := func(now time.Time) {
		closedOrPanic(client.a.Collect(now))
	}
	if err := client.collector.Start(client.rtoRate, collect); err != nil {
		return nil, err
	}

	client.wg.Add(1)
	go client.readUntilClosed()

	// Safety net: report (and close) clients that are garbage collected
	// without an explicit Close call.
	runtime.SetFinalizer(client, clientFinalizer)

	return client, nil
}

// clientFinalizer closes a Client that became unreachable and logs when it
// had not been closed explicitly.
func clientFinalizer(c *Client) {
	if c == nil {
		return
	}
	switch err := c.Close(); {
	case errors.Is(err, ErrClientClosed):
		// Already closed by the user: nothing to report.
	case err == nil:
		log.Println("client: called finalizer on non-closed client") // nolint
	default:
		log.Println("client: called finalizer on non-closed client:", err) // nolint
	}
}

// Connection wraps Reader, Writer and Closer interfaces.
type Connection interface {
	io.Reader
	io.Writer
	io.Closer
}

// ClientAgent is the Agent implementation that Client uses to track
// transactions and their deadlines.
type ClientAgent interface {
	// Process handles a message read from the connection.
	Process(*Message) error
	// Close shuts the agent down.
	Close() error
	// Start registers a transaction id that must complete before deadline.
	Start(id [TransactionIDSize]byte, deadline time.Time) error
	// Stop cancels a registered transaction.
	Stop(id [TransactionIDSize]byte) error
	// Collect is called on every collector tick with the current time.
	Collect(time.Time) error
	// SetHandler installs the callback that receives transaction events.
	SetHandler(h Handler) error
}

// Client simulates a "connection" to a STUN server: it writes requests to the
// underlying Connection and routes the resulting events back to
// per-transaction handlers.
type Client struct { rto int64 // time.Duration a ClientAgent c Connection close chan struct{} rtoRate time.Duration maxAttempts int32 closed bool closeConn bool // should call c.Close() while closing wg sync.WaitGroup clock Clock handler Handler collector Collector t map[transactionID]*clientTransaction // mux guards closed and t mux sync.RWMutex } // clientTransaction represents transaction in progress. // If transaction is succeed or failed, f will be called // provided by event. // Concurrent access is invalid. type clientTransaction struct { id transactionID attempt int32 calls int32 h Handler start time.Time rto time.Duration raw []byte } func (t *clientTransaction) handle(e Event) { if atomic.AddInt32(&t.calls, 1) == 1 { t.h(e) } } var clientTransactionPool = &sync.Pool{ //nolint:gochecknoglobals New: func() interface{} { return &clientTransaction{ raw: make([]byte, 1500), } }, } func acquireClientTransaction() *clientTransaction { return clientTransactionPool.Get().(*clientTransaction) //nolint:forcetypeassert } func putClientTransaction(t *clientTransaction) { t.raw = t.raw[:0] t.start = time.Time{} t.attempt = 0 t.id = transactionID{} clientTransactionPool.Put(t) } func (t *clientTransaction) nextTimeout(now time.Time) time.Time { return now.Add(time.Duration(t.attempt+1) * t.rto) } // start registers transaction. // // Could return ErrClientClosed, ErrTransactionExists. func (c *Client) start(t *clientTransaction) error { c.mux.Lock() defer c.mux.Unlock() if c.closed { return ErrClientClosed } _, exists := c.t[t.id] if exists { return ErrTransactionExists } c.t[t.id] = t return nil } // Clock abstracts the source of current time. type Clock interface { Now() time.Time } type systemClockService struct{} func (systemClockService) Now() time.Time { return time.Now() } func systemClock() systemClockService { return systemClockService{} } // SetRTO sets current RTO value. 
func (c *Client) SetRTO(rto time.Duration) {
	atomic.StoreInt64(&c.rto, int64(rto))
}

// StopErr occurs when Client fails to stop transaction while
// processing error.
//
//nolint:errname
type StopErr struct {
	Err   error // value returned by Stop()
	Cause error // error that caused Stop() call
}

// Error implements the error interface; nil fields are rendered as "".
func (e StopErr) Error() string {
	return fmt.Sprintf("error while stopping due to %s: %s", sprintErr(e.Cause), sprintErr(e.Err))
}

// CloseErr indicates client close failure. Either field may be nil.
//
//nolint:errname
type CloseErr struct {
	AgentErr      error
	ConnectionErr error
}

// sprintErr returns err.Error(), or "" for a nil error.
func sprintErr(err error) string {
	if err == nil {
		return "" //nolint:goconst
	}
	return err.Error()
}

// Error implements the error interface; nil fields are rendered as "".
func (c CloseErr) Error() string {
	return fmt.Sprintf("failed to close: %s (connection), %s (agent)", sprintErr(c.ConnectionErr), sprintErr(c.AgentErr))
}

// readUntilClosed reads messages from the connection and passes every
// successfully decoded one to the agent. It returns when c.close is closed,
// when the agent reports ErrAgentClosed, or when the connection can no
// longer produce data (io.EOF or net.ErrClosed).
//
// NOTE(review): the read buffer is 1024 bytes; larger datagrams presumably
// fail to decode and are skipped — confirm whether that limit is intended.
func (c *Client) readUntilClosed() {
	defer c.wg.Done()
	m := new(Message)
	m.Raw = make([]byte, 1024)
	for {
		select {
		case <-c.close:
			return
		default:
		}
		_, err := m.ReadFrom(c.c)
		switch {
		case err == nil:
			if pErr := c.a.Process(m); errors.Is(pErr, ErrAgentClosed) {
				return
			}
		case errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed):
			// Terminal: every further Read would fail immediately with the
			// same error, so looping would just spin the CPU until Close.
			return
		default:
			// Transient read or decode error: keep reading.
		}
	}
}

// closedOrPanic ignores nil and ErrAgentClosed and panics on any other error.
func closedOrPanic(err error) {
	if err == nil || errors.Is(err, ErrAgentClosed) {
		return
	}
	panic(err) //nolint
}

// tickerCollector is the default Collector, driven by a time.Ticker.
type tickerCollector struct {
	close chan struct{}
	wg    sync.WaitGroup
	clock Clock
}

// Collector calls function f with constant rate.
//
// The simple Collector is ticker which calls function on each tick.
type Collector interface {
	Start(rate time.Duration, f func(now time.Time)) error
	Close() error
}

// errNonPositiveCollectorRate is returned by tickerCollector.Start for a
// rate <= 0, which time.NewTicker would otherwise panic on.
var errNonPositiveCollectorRate = errors.New("collector rate must be positive")

// Start calls f with the collector clock's current time on every tick until
// Close is called. It returns an error instead of panicking when rate is not
// positive.
func (a *tickerCollector) Start(rate time.Duration, f func(now time.Time)) error {
	if rate <= 0 {
		return errNonPositiveCollectorRate
	}
	t := time.NewTicker(rate)
	a.wg.Add(1)
	go func() {
		defer a.wg.Done()
		defer t.Stop()
		for {
			select {
			case <-a.close:
				return
			case <-t.C:
				f(a.clock.Now())
			}
		}
	}()
	return nil
}

// Close stops the ticking goroutine and waits for it to exit. It must be
// called at most once (it closes a.close).
func (a *tickerCollector) Close() error {
	close(a.close)
	a.wg.Wait()
	return nil
}

// ErrClientClosed indicates that client is closed.
var ErrClientClosed = errors.New("client is closed")

// Close stops the collector, closes the agent and (unless WithNoConnClose was
// used) the connection, and waits for the read goroutine to exit. It returns
// ErrClientClosed if the client was already closed, a collector Close error
// as is, and CloseErr if the agent or the connection fails to close.
func (c *Client) Close() error {
	if err := c.checkInit(); err != nil {
		return err
	}
	c.mux.Lock()
	if c.closed {
		c.mux.Unlock()
		return ErrClientClosed
	}
	c.closed = true
	c.mux.Unlock()

	// closed is already set, so a failing collector must not abort shutdown.
	// A second Close would only report ErrClientClosed, and the agent, the
	// connection and the read goroutine would stay alive for good.
	collectorErr := c.collector.Close()

	var connErr error
	agentErr := c.a.Close()
	if c.closeConn {
		connErr = c.c.Close()
	}
	close(c.close)
	// NOTE: readUntilClosed only checks c.close between reads. With
	// WithNoConnClose the connection stays open, so this wait can last until
	// the pending Read returns.
	c.wg.Wait()

	if collectorErr != nil {
		return collectorErr
	}
	if agentErr == nil && connErr == nil {
		return nil
	}
	return CloseErr{
		AgentErr:      agentErr,
		ConnectionErr: connErr,
	}
}

// Indicate sends indication m to server. Shorthand to Start call
// with zero deadline and callback.
func (c *Client) Indicate(m *Message) error {
	return c.Start(m, nil)
}

// callbackWaitHandler blocks on wait() call until callback is called.
type callbackWaitHandler struct {
	handler   Handler
	callback  func(event Event)
	cond      *sync.Cond
	processed bool
}

// HandleEvent runs the callback under the condition lock, marks the event as
// processed and wakes up wait(). It panics if no callback is set.
func (s *callbackWaitHandler) HandleEvent(e Event) {
	s.cond.L.Lock()
	if s.callback == nil {
		panic("s.callback is nil") //nolint
	}
	s.callback(e)
	s.processed = true
	s.cond.Broadcast()
	s.cond.L.Unlock()
}

// wait blocks until HandleEvent ran, then resets the handler for reuse.
func (s *callbackWaitHandler) wait() {
	s.cond.L.Lock()
	for !s.processed {
		s.cond.Wait()
	}
	s.processed = false
	s.callback = nil
	s.cond.L.Unlock()
}

// setCallback installs f (which must be non-nil) and lazily binds s.handler
// to s.HandleEvent.
func (s *callbackWaitHandler) setCallback(f func(event Event)) {
	if f == nil {
		panic("f is nil") //nolint
	}
	s.cond.L.Lock()
	s.callback = f
	if s.handler == nil {
		s.handler = s.HandleEvent
	}
	s.cond.L.Unlock()
}

var callbackWaitHandlerPool = sync.Pool{ //nolint:gochecknoglobals
	New: func() interface{} {
		return &callbackWaitHandler{
			cond: sync.NewCond(new(sync.Mutex)),
		}
	},
}

// ErrClientNotInitialized means that client connection or agent is nil.
var ErrClientNotInitialized = errors.New("client not initialized")

// checkInit reports ErrClientNotInitialized for a nil Client or one that was
// not created by NewClient (missing connection, agent or close channel).
func (c *Client) checkInit() error {
	if c == nil || c.c == nil || c.a == nil || c.close == nil {
		return ErrClientNotInitialized
	}
	return nil
}

// Do is a Start wrapper that waits until the callback is called. If no
// callback is provided, Indicate is called instead.
//
// Do has cpu overhead due to blocking, see BenchmarkClient_Do.
// Use Start method for less overhead.
//
// If Start fails, its error is returned without waiting; f may still be
// invoked asynchronously in that case.
func (c *Client) Do(m *Message, f func(Event)) error {
	if err := c.checkInit(); err != nil {
		return err
	}
	if f == nil {
		return c.Indicate(m)
	}
	h := callbackWaitHandlerPool.Get().(*callbackWaitHandler) //nolint:forcetypeassert
	h.setCallback(f)
	if err := c.Start(m, h.handler); err != nil {
		// Do not recycle h. A failed Start can still have invoked (or later
		// invoke) h.handler, e.g. when the transaction completed before Stop
		// was attempted. A pooled h left with processed == true would make
		// the next Do return before its own callback ran.
		return err
	}
	h.wait()
	callbackWaitHandlerPool.Put(h)
	return nil
}

// delete removes the transaction with the given id from the pending set.
func (c *Client) delete(id transactionID) {
	c.mux.Lock()
	if c.t != nil {
		delete(c.t, id)
	}
	c.mux.Unlock()
}

type buffer struct {
	buf []byte
}

var bufferPool = &sync.Pool{ //nolint:gochecknoglobals
	New: func() interface{} {
		return &buffer{buf: make([]byte, 2048)}
	},
}

// handleAgentCallback is the Handler that NewClient installs on the agent.
//
// For an event whose transaction is pending in c.t, it either completes the
// transaction (success, attempts exhausted, or client closed) or retransmits
// the stored request with a new deadline. Events for unknown transactions go
// to c.handler. ErrTransactionStopped events and events that arrive after
// Close are not forwarded.
func (c *Client) handleAgentCallback(e Event) {
	c.mux.Lock()
	closed := c.closed
	t, found := c.t[e.TransactionID]
	if found {
		delete(c.t, t.id)
	}
	c.mux.Unlock()

	if !found {
		if !closed && c.handler != nil && !errors.Is(e.Error, ErrTransactionStopped) {
			c.handler(e)
		}
		// Ignoring.
		return
	}

	// After Close, pending transactions are completed with the event as is
	// rather than dropped. If the agent reports pending transactions while it
	// is being closed, dropping them would leave a Do() blocked in wait().
	if closed || atomic.LoadInt32(&c.maxAttempts) <= t.attempt || e.Error == nil {
		// Transaction completed.
		t.handle(e)
		putClientTransaction(t)
		return
	}

	// Doing re-transmission.
	t.attempt++

	// Copy the request out of t: once t is registered again below, another
	// goroutine may complete and recycle it. append grows the pooled buffer
	// instead of silently truncating requests larger than its capacity.
	b := bufferPool.Get().(*buffer) //nolint:forcetypeassert
	b.buf = append(b.buf[:0], t.raw...)
	defer bufferPool.Put(b)

	var (
		now     = c.clock.Now()
		timeOut = t.nextTimeout(now)
		id      = t.id
	)

	// Starting client transaction. On failure t was not registered, so there
	// is nothing to delete; deleting by id could drop another transaction
	// that now owns the same id.
	if startErr := c.start(t); startErr != nil {
		e.Error = startErr
		t.handle(e)
		putClientTransaction(t)
		return
	}

	// Starting agent transaction.
	if startErr := c.a.Start(id, timeOut); startErr != nil {
		c.delete(id)
		e.Error = startErr
		t.handle(e)
		putClientTransaction(t)
		return
	}

	// Writing message to connection again.
	if _, writeErr := c.c.Write(b.buf); writeErr != nil {
		c.delete(id)
		e.Error = writeErr

		// Stopping agent transaction instead of waiting until it's deadline.
		// This will call handleAgentCallback with "ErrTransactionStopped" error
		// which will be ignored.
		if stopErr := c.a.Stop(id); stopErr != nil {
			// Failed to stop agent transaction. Wrapping the error in StopError.
			e.Error = StopErr{
				Err:   stopErr,
				Cause: writeErr,
			}
		}
		t.handle(e)
		putClientTransaction(t)
		return
	}
}

// Start starts transaction (if h set) and writes message to server, handler
// is called asynchronously.
//
// If the transaction cannot be registered with the agent, the client-side
// registration is undone so no stale entry is left behind.
func (c *Client) Start(m *Message, h Handler) error {
	if err := c.checkInit(); err != nil {
		return err
	}
	c.mux.RLock()
	closed := c.closed
	c.mux.RUnlock()
	if closed {
		return ErrClientClosed
	}
	if h != nil {
		// Starting transaction only if h is set. Useful for indications.
		t := acquireClientTransaction()
		t.id = m.TransactionID
		t.start = c.clock.Now()
		t.h = h
		t.rto = time.Duration(atomic.LoadInt64(&c.rto))
		t.attempt = 0
		t.raw = append(t.raw[:0], m.Raw...)
		t.calls = 0

		d := t.nextTimeout(t.start)
		if err := c.start(t); err != nil {
			// t was never registered, so nothing else can reference it.
			putClientTransaction(t)
			return err
		}
		if err := c.a.Start(m.TransactionID, d); err != nil {
			// Otherwise the entry would stay in c.t forever, and a later
			// event with the same ID would be routed to h.
			c.delete(m.TransactionID)
			return err
		}
	}
	_, err := m.WriteTo(c.c)
	if err != nil && h != nil {
		c.delete(m.TransactionID)
		// Stopping transaction instead of waiting until deadline.
		if stopErr := c.a.Stop(m.TransactionID); stopErr != nil {
			return StopErr{
				Err:   stopErr,
				Cause: err,
			}
		}
	}
	return err
}