392 lines
11 KiB
Go
392 lines
11 KiB
Go
|
// Package datachannel implements WebRTC Data Channels.
|
||
|
package datachannel
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
|
||
|
"github.com/pion/logging"
|
||
|
"github.com/pion/sctp"
|
||
|
)
|
||
|
|
||
|
// receiveMTU is the size, in bytes, of the buffer Server reads the peer's
// initial DCEP message into (8 KiB).
const receiveMTU = 8 * 1024
|
||
|
|
||
|
// Reader is an extended io.Reader that additionally reports whether the
// message that was read is text.
type Reader interface {
	// ReadDataChannel reads into p and returns the number of bytes read,
	// whether the message was text, and any error.
	ReadDataChannel(p []byte) (n int, isString bool, err error)
}
|
||
|
|
||
|
// Writer is an extended io.Writer that additionally lets the caller mark
// a message as text.
type Writer interface {
	// WriteDataChannel writes p; isString marks it as a text message rather
	// than a binary one.
	WriteDataChannel(p []byte, isString bool) (n int, err error)
}
|
||
|
|
||
|
// ReadWriteCloser is an extended io.ReadWriteCloser
|
||
|
// that also implements our Reader and Writer.
|
||
|
type ReadWriteCloser interface {
|
||
|
io.Reader
|
||
|
io.Writer
|
||
|
Reader
|
||
|
Writer
|
||
|
io.Closer
|
||
|
}
|
||
|
|
||
|
// DataChannel represents a data channel
|
||
|
type DataChannel struct {
|
||
|
Config
|
||
|
|
||
|
// stats
|
||
|
messagesSent uint32
|
||
|
messagesReceived uint32
|
||
|
bytesSent uint64
|
||
|
bytesReceived uint64
|
||
|
|
||
|
mu sync.Mutex
|
||
|
onOpenCompleteHandler func()
|
||
|
openCompleteHandlerOnce sync.Once
|
||
|
|
||
|
stream *sctp.Stream
|
||
|
log logging.LeveledLogger
|
||
|
}
|
||
|
|
||
|
// Config is used to configure the data channel.
|
||
|
type Config struct {
|
||
|
ChannelType ChannelType
|
||
|
Negotiated bool
|
||
|
Priority uint16
|
||
|
ReliabilityParameter uint32
|
||
|
Label string
|
||
|
Protocol string
|
||
|
LoggerFactory logging.LoggerFactory
|
||
|
}
|
||
|
|
||
|
func newDataChannel(stream *sctp.Stream, config *Config) (*DataChannel, error) {
|
||
|
return &DataChannel{
|
||
|
Config: *config,
|
||
|
stream: stream,
|
||
|
log: config.LoggerFactory.NewLogger("datachannel"),
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
// Dial opens a data channels over SCTP
|
||
|
func Dial(a *sctp.Association, id uint16, config *Config) (*DataChannel, error) {
|
||
|
stream, err := a.OpenStream(id, sctp.PayloadTypeWebRTCBinary)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
dc, err := Client(stream, config)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return dc, nil
|
||
|
}
|
||
|
|
||
|
// Client opens a data channel over an SCTP stream
|
||
|
func Client(stream *sctp.Stream, config *Config) (*DataChannel, error) {
|
||
|
msg := &channelOpen{
|
||
|
ChannelType: config.ChannelType,
|
||
|
Priority: config.Priority,
|
||
|
ReliabilityParameter: config.ReliabilityParameter,
|
||
|
|
||
|
Label: []byte(config.Label),
|
||
|
Protocol: []byte(config.Protocol),
|
||
|
}
|
||
|
|
||
|
if !config.Negotiated {
|
||
|
rawMsg, err := msg.Marshal()
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to marshal ChannelOpen %w", err)
|
||
|
}
|
||
|
|
||
|
if _, err = stream.WriteSCTP(rawMsg, sctp.PayloadTypeWebRTCDCEP); err != nil {
|
||
|
return nil, fmt.Errorf("failed to send ChannelOpen %w", err)
|
||
|
}
|
||
|
}
|
||
|
return newDataChannel(stream, config)
|
||
|
}
|
||
|
|
||
|
// Accept is used to accept incoming data channels over SCTP
|
||
|
func Accept(a *sctp.Association, config *Config, existingChannels ...*DataChannel) (*DataChannel, error) {
|
||
|
stream, err := a.AcceptStream()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
for _, ch := range existingChannels {
|
||
|
if ch.StreamIdentifier() == stream.StreamIdentifier() {
|
||
|
ch.stream.SetDefaultPayloadType(sctp.PayloadTypeWebRTCBinary)
|
||
|
return ch, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
stream.SetDefaultPayloadType(sctp.PayloadTypeWebRTCBinary)
|
||
|
|
||
|
dc, err := Server(stream, config)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return dc, nil
|
||
|
}
|
||
|
|
||
|
// Server accepts a data channel over an SCTP stream
|
||
|
func Server(stream *sctp.Stream, config *Config) (*DataChannel, error) {
|
||
|
buffer := make([]byte, receiveMTU)
|
||
|
n, ppi, err := stream.ReadSCTP(buffer)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if ppi != sctp.PayloadTypeWebRTCDCEP {
|
||
|
return nil, fmt.Errorf("%w %s", ErrInvalidPayloadProtocolIdentifier, ppi)
|
||
|
}
|
||
|
|
||
|
openMsg, err := parseExpectDataChannelOpen(buffer[:n])
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to parse DataChannelOpen packet %w", err)
|
||
|
}
|
||
|
|
||
|
config.ChannelType = openMsg.ChannelType
|
||
|
config.Priority = openMsg.Priority
|
||
|
config.ReliabilityParameter = openMsg.ReliabilityParameter
|
||
|
config.Label = string(openMsg.Label)
|
||
|
config.Protocol = string(openMsg.Protocol)
|
||
|
|
||
|
dataChannel, err := newDataChannel(stream, config)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
err = dataChannel.writeDataChannelAck()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
err = dataChannel.commitReliabilityParams()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return dataChannel, nil
|
||
|
}
|
||
|
|
||
|
// Read reads a packet of len(p) bytes as binary data
|
||
|
func (c *DataChannel) Read(p []byte) (int, error) {
|
||
|
n, _, err := c.ReadDataChannel(p)
|
||
|
return n, err
|
||
|
}
|
||
|
|
||
|
// ReadDataChannel reads a packet of len(p) bytes
|
||
|
func (c *DataChannel) ReadDataChannel(p []byte) (int, bool, error) {
|
||
|
for {
|
||
|
n, ppi, err := c.stream.ReadSCTP(p)
|
||
|
if err == io.EOF {
|
||
|
// When the peer sees that an incoming stream was
|
||
|
// reset, it also resets its corresponding outgoing stream.
|
||
|
if closeErr := c.stream.Close(); closeErr != nil {
|
||
|
return 0, false, closeErr
|
||
|
}
|
||
|
}
|
||
|
if err != nil {
|
||
|
return 0, false, err
|
||
|
}
|
||
|
|
||
|
if ppi == sctp.PayloadTypeWebRTCDCEP {
|
||
|
if err = c.handleDCEP(p[:n]); err != nil {
|
||
|
c.log.Errorf("Failed to handle DCEP: %s", err.Error())
|
||
|
}
|
||
|
continue
|
||
|
} else if ppi == sctp.PayloadTypeWebRTCBinaryEmpty || ppi == sctp.PayloadTypeWebRTCStringEmpty {
|
||
|
n = 0
|
||
|
}
|
||
|
|
||
|
atomic.AddUint32(&c.messagesReceived, 1)
|
||
|
atomic.AddUint64(&c.bytesReceived, uint64(n))
|
||
|
|
||
|
isString := ppi == sctp.PayloadTypeWebRTCString || ppi == sctp.PayloadTypeWebRTCStringEmpty
|
||
|
return n, isString, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// MessagesSent returns the number of messages sent
|
||
|
func (c *DataChannel) MessagesSent() uint32 {
|
||
|
return atomic.LoadUint32(&c.messagesSent)
|
||
|
}
|
||
|
|
||
|
// MessagesReceived returns the number of messages received
|
||
|
func (c *DataChannel) MessagesReceived() uint32 {
|
||
|
return atomic.LoadUint32(&c.messagesReceived)
|
||
|
}
|
||
|
|
||
|
// OnOpen sets an event handler which is invoked when
|
||
|
// a DATA_CHANNEL_ACK message is received.
|
||
|
// The handler is called only on thefor the channel opened
|
||
|
// https://datatracker.ietf.org/doc/html/draft-ietf-rtcweb-data-protocol-09#section-5.2
|
||
|
func (c *DataChannel) OnOpen(f func()) {
|
||
|
c.mu.Lock()
|
||
|
c.openCompleteHandlerOnce = sync.Once{}
|
||
|
c.onOpenCompleteHandler = f
|
||
|
c.mu.Unlock()
|
||
|
}
|
||
|
|
||
|
func (c *DataChannel) onOpenComplete() {
|
||
|
c.mu.Lock()
|
||
|
hdlr := c.onOpenCompleteHandler
|
||
|
c.mu.Unlock()
|
||
|
|
||
|
if hdlr != nil {
|
||
|
go c.openCompleteHandlerOnce.Do(func() {
|
||
|
hdlr()
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// BytesSent returns the number of bytes sent
|
||
|
func (c *DataChannel) BytesSent() uint64 {
|
||
|
return atomic.LoadUint64(&c.bytesSent)
|
||
|
}
|
||
|
|
||
|
// BytesReceived returns the number of bytes received
|
||
|
func (c *DataChannel) BytesReceived() uint64 {
|
||
|
return atomic.LoadUint64(&c.bytesReceived)
|
||
|
}
|
||
|
|
||
|
// StreamIdentifier returns the Stream identifier associated to the stream.
|
||
|
func (c *DataChannel) StreamIdentifier() uint16 {
|
||
|
return c.stream.StreamIdentifier()
|
||
|
}
|
||
|
|
||
|
func (c *DataChannel) handleDCEP(data []byte) error {
|
||
|
msg, err := parse(data)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to parse DataChannel packet %w", err)
|
||
|
}
|
||
|
|
||
|
switch msg := msg.(type) {
|
||
|
case *channelAck:
|
||
|
c.log.Debug("Received DATA_CHANNEL_ACK")
|
||
|
if err = c.commitReliabilityParams(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
c.onOpenComplete()
|
||
|
default:
|
||
|
return fmt.Errorf("%w %v", ErrInvalidMessageType, msg)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Write writes len(p) bytes from p as binary data
|
||
|
func (c *DataChannel) Write(p []byte) (n int, err error) {
|
||
|
return c.WriteDataChannel(p, false)
|
||
|
}
|
||
|
|
||
|
// WriteDataChannel writes len(p) bytes from p
|
||
|
func (c *DataChannel) WriteDataChannel(p []byte, isString bool) (n int, err error) {
|
||
|
// https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-12#section-6.6
|
||
|
// SCTP does not support the sending of empty user messages. Therefore,
|
||
|
// if an empty message has to be sent, the appropriate PPID (WebRTC
|
||
|
// String Empty or WebRTC Binary Empty) is used and the SCTP user
|
||
|
// message of one zero byte is sent. When receiving an SCTP user
|
||
|
// message with one of these PPIDs, the receiver MUST ignore the SCTP
|
||
|
// user message and process it as an empty message.
|
||
|
var ppi sctp.PayloadProtocolIdentifier
|
||
|
switch {
|
||
|
case !isString && len(p) > 0:
|
||
|
ppi = sctp.PayloadTypeWebRTCBinary
|
||
|
case !isString && len(p) == 0:
|
||
|
ppi = sctp.PayloadTypeWebRTCBinaryEmpty
|
||
|
case isString && len(p) > 0:
|
||
|
ppi = sctp.PayloadTypeWebRTCString
|
||
|
case isString && len(p) == 0:
|
||
|
ppi = sctp.PayloadTypeWebRTCStringEmpty
|
||
|
}
|
||
|
|
||
|
atomic.AddUint32(&c.messagesSent, 1)
|
||
|
atomic.AddUint64(&c.bytesSent, uint64(len(p)))
|
||
|
|
||
|
if len(p) == 0 {
|
||
|
_, err := c.stream.WriteSCTP([]byte{0}, ppi)
|
||
|
return 0, err
|
||
|
}
|
||
|
return c.stream.WriteSCTP(p, ppi)
|
||
|
}
|
||
|
|
||
|
func (c *DataChannel) writeDataChannelAck() error {
|
||
|
ack := channelAck{}
|
||
|
ackMsg, err := ack.Marshal()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to marshal ChannelOpen ACK: %w", err)
|
||
|
}
|
||
|
|
||
|
if _, err = c.stream.WriteSCTP(ackMsg, sctp.PayloadTypeWebRTCDCEP); err != nil {
|
||
|
return fmt.Errorf("failed to send ChannelOpen ACK: %w", err)
|
||
|
}
|
||
|
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Close closes the DataChannel and the underlying SCTP stream.
|
||
|
func (c *DataChannel) Close() error {
|
||
|
// https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7
|
||
|
// Closing of a data channel MUST be signaled by resetting the
|
||
|
// corresponding outgoing streams [RFC6525]. This means that if one
|
||
|
// side decides to close the data channel, it resets the corresponding
|
||
|
// outgoing stream. When the peer sees that an incoming stream was
|
||
|
// reset, it also resets its corresponding outgoing stream. Once this
|
||
|
// is completed, the data channel is closed. Resetting a stream sets
|
||
|
// the Stream Sequence Numbers (SSNs) of the stream back to 'zero' with
|
||
|
// a corresponding notification to the application layer that the reset
|
||
|
// has been performed. Streams are available for reuse after a reset
|
||
|
// has been performed.
|
||
|
return c.stream.Close()
|
||
|
}
|
||
|
|
||
|
// BufferedAmount returns the number of bytes of data currently queued to be
|
||
|
// sent over this stream.
|
||
|
func (c *DataChannel) BufferedAmount() uint64 {
|
||
|
return c.stream.BufferedAmount()
|
||
|
}
|
||
|
|
||
|
// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing
|
||
|
// data that is considered "low." Defaults to 0.
|
||
|
func (c *DataChannel) BufferedAmountLowThreshold() uint64 {
|
||
|
return c.stream.BufferedAmountLowThreshold()
|
||
|
}
|
||
|
|
||
|
// SetBufferedAmountLowThreshold is used to update the threshold.
|
||
|
// See BufferedAmountLowThreshold().
|
||
|
func (c *DataChannel) SetBufferedAmountLowThreshold(th uint64) {
|
||
|
c.stream.SetBufferedAmountLowThreshold(th)
|
||
|
}
|
||
|
|
||
|
// OnBufferedAmountLow sets the callback handler which would be called when the
|
||
|
// number of bytes of outgoing data buffered is lower than the threshold.
|
||
|
func (c *DataChannel) OnBufferedAmountLow(f func()) {
|
||
|
c.stream.OnBufferedAmountLow(f)
|
||
|
}
|
||
|
|
||
|
func (c *DataChannel) commitReliabilityParams() error {
|
||
|
switch c.Config.ChannelType {
|
||
|
case ChannelTypeReliable:
|
||
|
c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeReliable, c.Config.ReliabilityParameter)
|
||
|
case ChannelTypeReliableUnordered:
|
||
|
c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeReliable, c.Config.ReliabilityParameter)
|
||
|
case ChannelTypePartialReliableRexmit:
|
||
|
c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeRexmit, c.Config.ReliabilityParameter)
|
||
|
case ChannelTypePartialReliableRexmitUnordered:
|
||
|
c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeRexmit, c.Config.ReliabilityParameter)
|
||
|
case ChannelTypePartialReliableTimed:
|
||
|
c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeTimed, c.Config.ReliabilityParameter)
|
||
|
case ChannelTypePartialReliableTimedUnordered:
|
||
|
c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeTimed, c.Config.ReliabilityParameter)
|
||
|
default:
|
||
|
return fmt.Errorf("%w %v", ErrInvalidChannelType, c.Config.ChannelType)
|
||
|
}
|
||
|
return nil
|
||
|
}
|