2022-04-06 11:48:16 +02:00

632 lines
14 KiB
Go

package stun
import (
"errors"
"fmt"
"io"
"log"
"net"
"runtime"
"sync"
"sync/atomic"
"time"
)
// Dial opens a connection with net.Dial(network, address) and wraps it in a
// Client created by NewClient with default options. The dial error, if any,
// is returned unchanged.
func Dial(network, address string) (*Client, error) {
	conn, err := net.Dial(network, address)
	if err == nil {
		return NewClient(conn)
	}
	return nil, err
}
// ErrNoConnection is returned by NewClient when it is given a nil
// connection.
var ErrNoConnection = errors.New("no connection provided")

// ClientOption configures a Client. NewClient applies the options in the
// order they are passed, after filling in its defaults.
type ClientOption func(c *Client)
// WithHandler sets the handler that receives agent events whose
// TransactionID is not registered in the Client (events carrying
// ErrTransactionStopped are not forwarded). This is useful for handling Data
// indications from a TURN server.
func WithHandler(h Handler) ClientOption {
	return func(c *Client) { c.handler = h }
}

// WithRTO sets the base retransmission timeout (RTO) as defined in the STUN
// RFC.
func WithRTO(rto time.Duration) ClientOption {
	ns := rto.Nanoseconds()
	return func(c *Client) { c.rto = ns }
}

// WithClock sets the Clock used by the client as its source of current time.
// The default collector created by NewClient uses the same clock.
func WithClock(clock Clock) ClientOption {
	setClock := func(c *Client) {
		c.clock = clock
	}
	return setClock
}

// WithTimeoutRate sets the rate passed to the collector's Start. This is the
// tick period of the default collector, and so the minimum resolution of the
// RTO timer.
func WithTimeoutRate(d time.Duration) ClientOption {
	setRate := func(c *Client) {
		c.rtoRate = d
	}
	return setRate
}
// WithAgent sets the STUN agent used by the client.
//
// When this option is not given, NewClient uses the agent implementation
// from this package (see agent.go).
func WithAgent(a ClientAgent) ClientOption {
	setAgent := func(c *Client) {
		c.a = a
	}
	return setAgent
}

// WithCollector sets the client's timeout collector, i.e. the ticker-like
// implementation that calls a function on each tick.
func WithCollector(coll Collector) ClientOption {
	setCollector := func(c *Client) {
		c.collector = coll
	}
	return setCollector
}

// WithNoConnClose makes Close leave the underlying connection open.
var WithNoConnClose ClientOption = func(c *Client) { c.closeConn = false }
// WithNoRetransmit disables retransmissions (maximum attempts become 0).
// If no RTO is set when the option runs, the RTO becomes
// defaultMaxAttempts * defaultRTO. A transaction then gets one long timeout
// instead of a short one.
//
// Useful for TCP connections where transport handles RTO.
func WithNoRetransmit(c *Client) {
	c.maxAttempts = 0
	if c.rto != 0 {
		return
	}
	c.rto = int64(defaultRTO) * defaultMaxAttempts
}

const (
	defaultMaxAttempts = 7
	defaultRTO         = 300 * time.Millisecond
	defaultTimeoutRate = 5 * time.Millisecond
)
// NewClient initializes a new Client on conn from the provided options. It
// fills in default values for anything the options did not set and starts
// the internal goroutines. Call Close after using the Client to close conn
// and release resources.
//
// The conn will be closed on Close call. Use WithNoConnClose option to
// prevent that.
//
// Defaults: system clock, RTO defaultRTO, timeout rate defaultTimeoutRate,
// defaultMaxAttempts attempts, the package agent (NewAgent) and a
// ticker-based collector that uses the client clock. An RTO of zero after
// all options have run is treated as "not set" and replaced by defaultRTO.
//
// Note that user should handle the protocol multiplexing, client does not
// provide any API for it, so if you need to read application data, wrap the
// connection with your (de-)multiplexer and pass the wrapper as conn.
//
// Returns ErrNoConnection if conn is nil, or the error from the agent's
// SetHandler or the collector's Start.
func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
	c := &Client{
		close: make(chan struct{}),
		c:     conn,
		clock: systemClock,
		// rto is deliberately left zero here so options can tell whether it
		// was set explicitly (WithNoRetransmit only raises an unset RTO).
		// The default is applied after the options below.
		rtoRate:     defaultTimeoutRate,
		t:           make(map[transactionID]*clientTransaction, 100),
		maxAttempts: defaultMaxAttempts,
		closeConn:   true,
	}
	for _, o := range options {
		o(c)
	}
	if c.rto == 0 {
		c.rto = int64(defaultRTO)
	}
	if c.c == nil {
		return nil, ErrNoConnection
	}
	if c.a == nil {
		c.a = NewAgent(nil)
	}
	if err := c.a.SetHandler(c.handleAgentCallback); err != nil {
		return nil, err
	}
	if c.collector == nil {
		c.collector = &tickerCollector{
			close: make(chan struct{}),
			clock: c.clock,
		}
	}
	if err := c.collector.Start(c.rtoRate, func(t time.Time) {
		closedOrPanic(c.a.Collect(t))
	}); err != nil {
		return nil, err
	}
	c.wg.Add(1)
	go c.readUntilClosed()
	runtime.SetFinalizer(c, clientFinalizer)
	return c, nil
}
// clientFinalizer is the runtime finalizer installed by NewClient. It closes
// a client that was never closed and logs that this happened. A client that
// was already closed (ErrClientClosed) is left alone silently.
func clientFinalizer(c *Client) {
	if c == nil {
		return
	}
	switch err := c.Close(); err {
	case ErrClientClosed:
		return
	case nil:
		log.Println("client: called finalizer on non-closed client") // nolint
	default:
		log.Println("client: called finalizer on non-closed client:", err) // nolint
	}
}
// Connection is the transport the Client reads from, writes to and
// (unless WithNoConnClose is used) closes: any io.ReadWriteCloser.
type Connection interface {
	io.ReadWriteCloser
}
// ClientAgent is Agent implementation that is used by Client to
// process transactions.
type ClientAgent interface {
Process(*Message) error
Close() error
Start(id [TransactionIDSize]byte, deadline time.Time) error
Stop(id [TransactionIDSize]byte) error
Collect(time.Time) error
SetHandler(h Handler) error
}
// Client simulates "connection" to STUN server.
//
// A Client is created by NewClient (or Dial); a zero Client fails checkInit
// with ErrClientNotInitialized. Release it with Close.
type Client struct {
	// rto is the base retransmission timeout in nanoseconds (a
	// time.Duration). SetRTO and Start access it with sync/atomic 64-bit
	// operations.
	// NOTE(review): presumably kept as the first field so it stays 64-bit
	// aligned for atomic access on 32-bit platforms (see sync/atomic docs);
	// do not move it.
	rto         int64 // time.Duration
	a           ClientAgent
	c           Connection
	close       chan struct{}  // closed by Close to stop readUntilClosed
	rtoRate     time.Duration  // rate passed to collector.Start
	maxAttempts int32          // retransmission limit; read with sync/atomic
	closed      bool
	closeConn   bool           // should call c.Close() while closing
	wg          sync.WaitGroup // tracks the readUntilClosed goroutine
	clock       Clock
	handler     Handler // WithHandler handler for unregistered transactions
	collector   Collector
	t           map[transactionID]*clientTransaction // in-flight transactions by ID
	// mux guards closed and t
	mux sync.RWMutex
}
// clientTransaction represents a transaction in progress.
// When the transaction succeeds or finally fails, its handler h is called
// with the resulting event (at most once, see handle).
// Concurrent access is invalid.
// NOTE(review): handle uses an atomic counter, so concurrent handle calls
// appear to be tolerated; other fields are not synchronized.
type clientTransaction struct {
	id      transactionID
	attempt int32         // retransmissions done so far; scales nextTimeout
	calls   int32         // number of handle calls; updated atomically
	h       Handler       // receives the final event
	start   time.Time     // clock time when Start registered the transaction
	rto     time.Duration // base RTO captured from the client at Start
	raw     []byte        // copy of the encoded message, re-sent on retransmission
}
// handle passes e to the transaction handler only on its first invocation.
// Later invocations are ignored because of the atomic calls counter.
func (t *clientTransaction) handle(e Event) {
	if n := atomic.AddInt32(&t.calls, 1); n != 1 {
		return
	}
	t.h(e)
}
// clientTransactionPool recycles clientTransaction values. A fresh value
// starts with a 1500-byte raw buffer.
var clientTransactionPool = &sync.Pool{
	New: func() interface{} {
		t := new(clientTransaction)
		t.raw = make([]byte, 1500)
		return t
	},
}

// acquireClientTransaction takes a transaction from the pool; callers must
// initialize every field they rely on.
func acquireClientTransaction() *clientTransaction {
	v := clientTransactionPool.Get()
	return v.(*clientTransaction)
}

// putClientTransaction clears id, attempt, start and raw (keeping raw's
// capacity) and returns t to the pool.
func putClientTransaction(t *clientTransaction) {
	t.id = transactionID{}
	t.attempt = 0
	t.start = time.Time{}
	t.raw = t.raw[:0]
	clientTransactionPool.Put(t)
}

// nextTimeout returns the deadline for the current attempt:
// now + (attempt+1)*rto, so each retransmission waits longer.
func (t *clientTransaction) nextTimeout(now time.Time) time.Time {
	step := time.Duration(t.attempt + 1)
	return now.Add(step * t.rto)
}
// start registers t in the client's transaction table under t.id.
//
// Returns ErrClientClosed once Close has been called, or
// ErrTransactionExists if a transaction with the same ID is already
// registered.
func (c *Client) start(t *clientTransaction) error {
	c.mux.Lock()
	defer c.mux.Unlock()
	if c.closed {
		return ErrClientClosed
	}
	if _, dup := c.t[t.id]; dup {
		return ErrTransactionExists
	}
	c.t[t.id] = t
	return nil
}
// Clock abstracts the source of current time.
type Clock interface {
	Now() time.Time
}

// systemClockService is the default Clock, backed by time.Now.
type systemClockService struct{}

// systemClock is the Clock used by NewClient unless WithClock is given.
var systemClock = systemClockService{}

// Now returns time.Now().
func (systemClockService) Now() time.Time {
	return time.Now()
}
// SetRTO atomically replaces the base RTO. Transactions started afterwards
// use the new value.
func (c *Client) SetRTO(rto time.Duration) {
	ns := rto.Nanoseconds()
	atomic.StoreInt64(&c.rto, ns)
}
// StopErr occurs when Client fails to stop a transaction while handling
// another error (Cause).
type StopErr struct {
	Err   error // value returned by Stop()
	Cause error // error that caused Stop() call
}

// Error reports the cause first and then the Stop failure.
func (e StopErr) Error() string {
	cause, stop := sprintErr(e.Cause), sprintErr(e.Err)
	return fmt.Sprintf("error while stopping due to %s: %s", cause, stop)
}

// CloseErr indicates client close failure.
type CloseErr struct {
	AgentErr      error
	ConnectionErr error
}

// sprintErr returns err.Error(), or "<nil>" for a nil error.
func sprintErr(err error) string {
	if err != nil {
		return err.Error()
	}
	return "<nil>"
}

// Error reports the connection error first, then the agent error.
func (c CloseErr) Error() string {
	conn, agent := sprintErr(c.ConnectionErr), sprintErr(c.AgentErr)
	return "failed to close: " + conn + " (connection), " + agent + " (agent)"
}
// readUntilClosed is the read loop started by NewClient (tracked by c.wg).
// It reads messages from the connection and passes every one read without
// error to the agent's Process.
//
// It returns when:
//   - c.close is closed;
//   - Process reports ErrAgentClosed;
//   - the connection reports io.EOF.
//
// Any other ReadFrom error (presumably transient read failures or malformed
// packets) is skipped, so that one bad packet does not stop the client.
func (c *Client) readUntilClosed() {
	defer c.wg.Done()
	m := new(Message)
	m.Raw = make([]byte, 1024)
	for {
		select {
		case <-c.close:
			return
		default:
		}
		_, err := m.ReadFrom(c.c)
		if err == io.EOF {
			// The peer closed the stream. Every further Read returns io.EOF
			// immediately, so looping would just burn CPU until Close.
			// Pending transactions still time out through the collector.
			return
		}
		if err != nil {
			continue
		}
		if pErr := c.a.Process(m); pErr == ErrAgentClosed {
			return
		}
	}
}
// closedOrPanic tolerates a nil error and ErrAgentClosed, and panics on any
// other error. The collector callback uses it around the agent's Collect.
func closedOrPanic(err error) {
	if err != nil && err != ErrAgentClosed {
		panic(err) // nolint
	}
}
// tickerCollector is the default Collector. It calls the supplied function
// on each tick of a time.Ticker until Close is called.
type tickerCollector struct {
	close chan struct{} // closed by Close to stop the ticker goroutine
	wg    sync.WaitGroup
	clock Clock // source of the time passed to the callback
}

// Collector calls a function periodically. NewClient uses it to call the
// agent's Collect.
//
// The simple Collector is ticker which calls function on each tick.
type Collector interface {
	Start(rate time.Duration, f func(now time.Time)) error
	Close() error
}

// Start launches a goroutine that calls f with a.clock.Now() every rate.
// The goroutine stops the ticker when Close is called. Start always
// returns nil.
func (a *tickerCollector) Start(rate time.Duration, f func(now time.Time)) error {
	ticker := time.NewTicker(rate)
	a.wg.Add(1)
	go func() {
		defer a.wg.Done()
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				f(a.clock.Now())
			case <-a.close:
				return
			}
		}
	}()
	return nil
}

// Close signals the goroutine started by Start to exit and waits for it.
// It must be called at most once, because it closes a.close.
func (a *tickerCollector) Close() error {
	close(a.close)
	a.wg.Wait()
	return nil
}
// ErrClientClosed indicates that client is closed.
var ErrClientClosed = errors.New("client is closed")

// Close shuts the client down. In order, it:
//  1. stops the collector;
//  2. closes the agent;
//  3. closes the connection (unless WithNoConnClose was used);
//  4. waits for the read goroutine to exit.
//
// It returns:
//   - ErrClientNotInitialized for an uninitialized client;
//   - ErrClientClosed if Close was already called;
//   - the collector's error if stopping the collector failed (the rest of
//     the shutdown still happens);
//   - CloseErr if closing the agent or the connection failed.
//
// NOTE(review): with WithNoConnClose the read goroutine may stay blocked in
// Read until the connection delivers data or fails, so Close may block
// until then.
func (c *Client) Close() error {
	if err := c.checkInit(); err != nil {
		return err
	}
	c.mux.Lock()
	if c.closed {
		c.mux.Unlock()
		return ErrClientClosed
	}
	c.closed = true
	c.mux.Unlock()
	// The client is already marked closed and a second Close would only
	// return ErrClientClosed. A collector failure therefore must not abort
	// the shutdown, or the agent, the connection and the read goroutine
	// would leak for good.
	collectorErr := c.collector.Close()
	var connErr error
	agentErr := c.a.Close()
	if c.closeConn {
		connErr = c.c.Close()
	}
	close(c.close)
	c.wg.Wait()
	if collectorErr != nil {
		return collectorErr
	}
	if agentErr == nil && connErr == nil {
		return nil
	}
	return CloseErr{
		AgentErr:      agentErr,
		ConnectionErr: connErr,
	}
}
// Indicate writes indication m to the server without registering a
// transaction; it is Start with a nil handler.
func (c *Client) Indicate(m *Message) error {
	var noHandler Handler
	return c.Start(m, noHandler)
}
// callbackWaitHandler lets Do block until the transaction callback has run.
// HandleEvent (used as the transaction Handler) runs the callback and
// signals cond; wait blocks on cond until that has happened.
type callbackWaitHandler struct {
	handler   Handler           // cached method value of HandleEvent
	callback  func(event Event) // callback of the current Do call
	cond      *sync.Cond
	processed bool // set by HandleEvent, cleared by wait
}

// HandleEvent runs the registered callback with e while holding the cond
// lock, then wakes up wait. It panics if no callback is registered.
func (s *callbackWaitHandler) HandleEvent(e Event) {
	lock := s.cond.L
	lock.Lock()
	fn := s.callback
	if fn == nil {
		panic("s.callback is nil") // nolint
	}
	fn(e)
	s.processed = true
	s.cond.Broadcast()
	lock.Unlock()
}

// wait blocks until HandleEvent has run, then resets the state for reuse.
func (s *callbackWaitHandler) wait() {
	s.cond.L.Lock()
	for !s.processed {
		s.cond.Wait()
	}
	s.callback, s.processed = nil, false
	s.cond.L.Unlock()
}

// setCallback registers f for the next event and lazily caches the
// HandleEvent method value in s.handler. It panics if f is nil.
func (s *callbackWaitHandler) setCallback(f func(event Event)) {
	if f == nil {
		panic("f is nil") // nolint
	}
	s.cond.L.Lock()
	if s.handler == nil {
		s.handler = s.HandleEvent
	}
	s.callback = f
	s.cond.L.Unlock()
}

// callbackWaitHandlerPool recycles wait handlers between Do calls.
var callbackWaitHandlerPool = sync.Pool{
	New: func() interface{} {
		mu := new(sync.Mutex)
		return &callbackWaitHandler{cond: sync.NewCond(mu)}
	},
}
// ErrClientNotInitialized means that the client, its connection, its agent
// or its close channel is nil (e.g. a zero Client not built by NewClient).
var ErrClientNotInitialized = errors.New("client not initialized")

// checkInit returns ErrClientNotInitialized unless c was fully set up.
func (c *Client) checkInit() error {
	switch {
	case c == nil, c.c == nil, c.a == nil, c.close == nil:
		return ErrClientNotInitialized
	}
	return nil
}
// Do starts a transaction for m and blocks until f has been called with its
// event. With a nil f it behaves like Indicate and does not block.
//
// Do has CPU overhead due to blocking, see BenchmarkClient_Do.
// Use Start method for less overhead.
func (c *Client) Do(m *Message, f func(Event)) error {
	if err := c.checkInit(); err != nil {
		return err
	}
	if f == nil {
		return c.Indicate(m)
	}
	waiter := callbackWaitHandlerPool.Get().(*callbackWaitHandler)
	waiter.setCallback(f)
	defer callbackWaitHandlerPool.Put(waiter)
	err := c.Start(m, waiter.handler)
	if err == nil {
		waiter.wait()
	}
	return err
}
// delete unregisters the transaction with the given id, if present.
// Deleting from a nil map is a no-op in Go, so no nil check is needed.
func (c *Client) delete(id transactionID) {
	c.mux.Lock()
	defer c.mux.Unlock()
	delete(c.t, id)
}
// buffer is a poolable byte slice used for retransmissions.
type buffer struct {
	buf []byte
}

// bufferPool hands out buffers that start with 2048 bytes.
var bufferPool = &sync.Pool{
	New: func() interface{} {
		b := new(buffer)
		b.buf = make([]byte, 2048)
		return b
	},
}
// handleAgentCallback is the Handler that NewClient installs on the agent.
// It routes each agent event as follows:
//
//   - Unregistered transaction: the event goes to the WithHandler handler,
//     if any. It is dropped instead when it carries ErrTransactionStopped
//     or when the client is closed.
//   - Registered transaction that succeeded (e.Error == nil), reached the
//     attempt limit, or arrived while the client is closing: the
//     transaction completes, its handler runs and it returns to the pool.
//   - Any other error: the transaction is retransmitted. It is
//     re-registered, restarted in the agent with a longer deadline, and its
//     raw message is written again.
func (c *Client) handleAgentCallback(e Event) {
	c.mux.Lock()
	closed := c.closed
	t, found := c.t[e.TransactionID]
	if found {
		delete(c.t, t.id)
	}
	c.mux.Unlock()
	if !found {
		if !closed && c.handler != nil && e.Error != ErrTransactionStopped {
			c.handler(e)
		}
		// Ignoring.
		return
	}
	if closed || atomic.LoadInt32(&c.maxAttempts) <= t.attempt || e.Error == nil {
		// Transaction completed. While closing (presumably the agent
		// terminating pending transactions from Close), the event is still
		// delivered rather than dropped. The handler then runs and callers
		// blocked in Do are released instead of waiting forever.
		t.handle(e)
		putClientTransaction(t)
		return
	}
	// Doing re-transmission.
	t.attempt++
	// Copy the raw message out of t. Once t is re-registered below, another
	// callback may complete it and return it to the pool. append (rather
	// than copying into the buffer's fixed capacity) avoids truncating
	// messages longer than the pooled buffer.
	b := bufferPool.Get().(*buffer)
	b.buf = append(b.buf[:0], t.raw...)
	defer bufferPool.Put(b)
	var (
		now     = c.clock.Now()
		timeOut = t.nextTimeout(now)
		id      = t.id // captured before t becomes visible to other goroutines
	)
	// Starting client transaction.
	if startErr := c.start(t); startErr != nil {
		c.delete(id)
		e.Error = startErr
		t.handle(e)
		putClientTransaction(t)
		return
	}
	// Starting agent transaction.
	if startErr := c.a.Start(id, timeOut); startErr != nil {
		c.delete(id)
		e.Error = startErr
		t.handle(e)
		putClientTransaction(t)
		return
	}
	// Writing message to connection again.
	_, writeErr := c.c.Write(b.buf)
	if writeErr != nil {
		c.delete(id)
		e.Error = writeErr
		// Stopping agent transaction instead of waiting until it's deadline.
		// This will call handleAgentCallback with "ErrTransactionStopped" error
		// which will be ignored.
		if stopErr := c.a.Stop(id); stopErr != nil {
			// Failed to stop agent transaction. Wrapping the error in StopError.
			e.Error = StopErr{
				Err:   stopErr,
				Cause: writeErr,
			}
		}
		t.handle(e)
		putClientTransaction(t)
		return
	}
}
// Start writes m to the server. When h is non-nil, Start first registers a
// transaction for m.TransactionID and starts it in the agent, with deadline
// now+RTO; h is then called asynchronously with the final event. With a nil
// h the message is only written (see Indicate).
//
// It returns:
//   - ErrClientNotInitialized or ErrClientClosed;
//   - ErrTransactionExists for a duplicate transaction ID;
//   - the agent's Start error;
//   - the write error;
//   - StopErr if the write failed and stopping the agent transaction also
//     failed.
//
// On any error the transaction is left unregistered.
func (c *Client) Start(m *Message, h Handler) error {
	if err := c.checkInit(); err != nil {
		return err
	}
	c.mux.RLock()
	closed := c.closed
	c.mux.RUnlock()
	if closed {
		return ErrClientClosed
	}
	if h != nil {
		// Starting transaction only if h is set. Useful for indications.
		t := acquireClientTransaction()
		t.id = m.TransactionID
		t.start = c.clock.Now()
		t.h = h
		t.rto = time.Duration(atomic.LoadInt64(&c.rto))
		t.attempt = 0
		t.raw = append(t.raw[:0], m.Raw...)
		t.calls = 0
		d := t.nextTimeout(t.start)
		if err := c.start(t); err != nil {
			// t was never registered, so nothing else references it.
			putClientTransaction(t)
			return err
		}
		if err := c.a.Start(m.TransactionID, d); err != nil {
			// Unregister the transaction. Otherwise it stays in c.t, and a
			// later agent event with this ID would invoke h even though Start
			// reported failure. t is not returned to the pool, because a
			// concurrent callback may already have taken it.
			c.delete(m.TransactionID)
			return err
		}
	}
	_, err := m.WriteTo(c.c)
	if err != nil && h != nil {
		c.delete(m.TransactionID)
		// Stopping transaction instead of waiting until deadline.
		if stopErr := c.a.Stop(m.TransactionID); stopErr != nil {
			return StopErr{
				Err:   stopErr,
				Cause: err,
			}
		}
	}
	return err
}