// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package webrtc import ( "errors" "io" "math" "sync" "time" "github.com/pion/datachannel" "github.com/pion/logging" "github.com/pion/sctp" "github.com/pion/webrtc/v3/pkg/rtcerr" ) const sctpMaxChannels = uint16(65535) // SCTPTransport provides details about the SCTP transport. type SCTPTransport struct { lock sync.RWMutex dtlsTransport *DTLSTransport // State represents the current state of the SCTP transport. state SCTPTransportState // SCTPTransportState doesn't have an enum to distinguish between New/Connecting // so we need a dedicated field isStarted bool // MaxMessageSize represents the maximum size of data that can be passed to // DataChannel's send() method. maxMessageSize float64 // MaxChannels represents the maximum amount of DataChannel's that can // be used simultaneously. maxChannels *uint16 // OnStateChange func() onErrorHandler func(error) sctpAssociation *sctp.Association onDataChannelHandler func(*DataChannel) onDataChannelOpenedHandler func(*DataChannel) // DataChannels dataChannels []*DataChannel dataChannelsOpened uint32 dataChannelsRequested uint32 dataChannelsAccepted uint32 api *API log logging.LeveledLogger } // NewSCTPTransport creates a new SCTPTransport. // This constructor is part of the ORTC API. It is not // meant to be used together with the basic WebRTC API. func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport { res := &SCTPTransport{ dtlsTransport: dtls, state: SCTPTransportStateConnecting, api: api, log: api.settingEngine.LoggerFactory.NewLogger("ortc"), } res.updateMessageSize() res.updateMaxChannels() return res } // Transport returns the DTLSTransport instance the SCTPTransport is sending over. func (r *SCTPTransport) Transport() *DTLSTransport { r.lock.RLock() defer r.lock.RUnlock() return r.dtlsTransport } // GetCapabilities returns the SCTPCapabilities of the SCTPTransport. 
func (r *SCTPTransport) GetCapabilities() SCTPCapabilities {
	return SCTPCapabilities{
		MaxMessageSize: 0,
	}
}

// Start the SCTPTransport. Since both local and remote parties must mutually
// create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish
// a connection over SCTP.
//
// Only the first call does any work; later calls return nil, even if the
// first call failed. Data channels registered before the association existed
// that are still connecting are opened here, and a goroutine is launched to
// accept data channels announced by the remote peer.
func (r *SCTPTransport) Start(SCTPCapabilities) error {
	// isStarted is guarded by r.lock like the other mutable fields, so two
	// concurrent Start calls cannot both go on to create an association.
	r.lock.Lock()
	if r.isStarted {
		r.lock.Unlock()
		return nil
	}
	r.isStarted = true
	r.lock.Unlock()

	dtlsTransport := r.Transport()
	if dtlsTransport == nil || dtlsTransport.conn == nil {
		return errSCTPTransportDTLS
	}

	sctpAssociation, err := sctp.Client(sctp.Config{
		NetConn:              dtlsTransport.conn,
		MaxReceiveBufferSize: r.api.settingEngine.sctp.maxReceiveBufferSize,
		EnableZeroChecksum:   r.api.settingEngine.sctp.enableZeroChecksum,
		LoggerFactory:        r.api.settingEngine.LoggerFactory,
	})
	if err != nil {
		return err
	}

	r.lock.Lock()
	r.sctpAssociation = sctpAssociation
	r.state = SCTPTransportStateConnected
	// Snapshot the list so open() below runs without r.lock held.
	dataChannels := append([]*DataChannel{}, r.dataChannels...)
	r.lock.Unlock()

	var openedDCCount uint32
	for _, d := range dataChannels {
		if d.ReadyState() != DataChannelStateConnecting {
			continue
		}
		if err := d.open(r); err != nil {
			r.log.Warnf("failed to open data channel: %s", err)
			continue
		}
		openedDCCount++
	}

	r.lock.Lock()
	r.dataChannelsOpened += openedDCCount
	r.lock.Unlock()

	go r.acceptDataChannels(sctpAssociation)

	return nil
}

// Stop stops the SCTPTransport by closing its association, if any.
// If closing fails, the error is returned and the association and state are
// left unchanged; otherwise the transport moves to SCTPTransportStateClosed.
func (r *SCTPTransport) Stop() error {
	r.lock.Lock()
	defer r.lock.Unlock()
	if r.sctpAssociation == nil {
		return nil
	}

	if err := r.sctpAssociation.Close(); err != nil {
		return err
	}

	r.sctpAssociation = nil
	r.state = SCTPTransportStateClosed

	return nil
}

// acceptDataChannels accepts data channels opened by the remote peer on
// association a until datachannel.Accept fails. io.EOF ends the loop
// silently; any other error is logged and reported through onError. A failure
// to build the DataChannel wrapper also stops the loop.
//
// Channels that already have an underlying datachannel.DataChannel when the
// loop starts are passed to Accept and skipped if they are seen again.
func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) {
	r.lock.RLock()
	dataChannels := make([]*datachannel.DataChannel, 0, len(r.dataChannels))
	for _, dc := range r.dataChannels {
		// Read the field once under dc.mu. Checking it under the lock and
		// re-reading it afterwards would race with concurrent writers.
		dc.mu.Lock()
		existing := dc.dataChannel
		dc.mu.Unlock()
		if existing == nil {
			continue
		}
		dataChannels = append(dataChannels, existing)
	}
	r.lock.RUnlock()

ACCEPT:
	for {
		dc, err := datachannel.Accept(a, &datachannel.Config{
			LoggerFactory: r.api.settingEngine.LoggerFactory,
		}, dataChannels...)
		if err != nil {
			if !errors.Is(err, io.EOF) {
				r.log.Errorf("Failed to accept data channel: %v", err)
				r.onError(err)
			}
			return
		}
		for _, ch := range dataChannels {
			if ch.StreamIdentifier() == dc.StreamIdentifier() {
				continue ACCEPT
			}
		}

		var (
			maxRetransmits    *uint16
			maxPacketLifeTime *uint16
		)
		// Narrowed to uint16 to match the MaxRetransmits/MaxPacketLifeTime
		// parameter types.
		val := uint16(dc.Config.ReliabilityParameter)
		ordered := true

		switch dc.Config.ChannelType {
		case datachannel.ChannelTypeReliable:
			ordered = true
		case datachannel.ChannelTypeReliableUnordered:
			ordered = false
		case datachannel.ChannelTypePartialReliableRexmit:
			ordered = true
			maxRetransmits = &val
		case datachannel.ChannelTypePartialReliableRexmitUnordered:
			ordered = false
			maxRetransmits = &val
		case datachannel.ChannelTypePartialReliableTimed:
			ordered = true
			maxPacketLifeTime = &val
		case datachannel.ChannelTypePartialReliableTimedUnordered:
			ordered = false
			maxPacketLifeTime = &val
		default:
		}

		sid := dc.StreamIdentifier()
		rtcDC, err := r.api.newDataChannel(&DataChannelParameters{
			ID:                &sid,
			Label:             dc.Config.Label,
			Protocol:          dc.Config.Protocol,
			Negotiated:        dc.Config.Negotiated,
			Ordered:           ordered,
			MaxPacketLifeTime: maxPacketLifeTime,
			MaxRetransmits:    maxRetransmits,
		}, r, r.api.settingEngine.LoggerFactory.NewLogger("ortc"))
		if err != nil {
			r.log.Errorf("Failed to accept data channel: %v", err)
			r.onError(err)
			return
		}

		// Wait for the OnDataChannel handler to return before opening the
		// channel, so its setup is in place when the channel's events fire.
		<-r.onDataChannel(rtcDC)
		rtcDC.handleOpen(dc, true, dc.Config.Negotiated)

		r.lock.Lock()
		r.dataChannelsOpened++
		handler := r.onDataChannelOpenedHandler
		r.lock.Unlock()

		if handler != nil {
			handler(rtcDC)
		}
	}
}

// OnError sets an event handler which is invoked when an SCTP connection
// error occurs (for example, when accepting a remote data channel fails).
func (r *SCTPTransport) OnError(f func(err error)) {
	r.lock.Lock()
	defer r.lock.Unlock()
	r.onErrorHandler = f
}

// onError passes err to the registered OnError handler, if any. The handler
// runs on a new goroutine so the caller is not blocked by user code.
func (r *SCTPTransport) onError(err error) {
	r.lock.RLock()
	handler := r.onErrorHandler
	r.lock.RUnlock()

	if handler != nil {
		go handler(err)
	}
}

// OnDataChannel sets an event handler which is invoked when the remote peer
// opens a new data channel on this transport. The accept loop waits for the
// handler to return before opening the channel.
func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) {
	r.lock.Lock()
	defer r.lock.Unlock()
	r.onDataChannelHandler = f
}

// OnDataChannelOpened sets an event handler which is invoked when a data
// channel opened by the remote peer has been opened on this transport.
func (r *SCTPTransport) OnDataChannelOpened(f func(*DataChannel)) {
	r.lock.Lock()
	defer r.lock.Unlock()
	r.onDataChannelOpenedHandler = f
}

// onDataChannel registers dc with the transport, counts it as accepted and
// invokes the OnDataChannel handler for it. The returned channel is closed
// once the handler has returned, or immediately when no handler is set.
// A nil dc is ignored: it is neither tracked nor counted.
func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) {
	done = make(chan struct{})
	if dc == nil {
		// Never store a nil entry: acceptDataChannels and
		// generateAndSetDataChannelID dereference every tracked channel.
		close(done)
		return done
	}

	r.lock.Lock()
	r.dataChannels = append(r.dataChannels, dc)
	r.dataChannelsAccepted++
	handler := r.onDataChannelHandler
	r.lock.Unlock()

	if handler == nil {
		close(done)
		return done
	}

	// The handler runs on its own goroutine, but callers wait on done, so any
	// setup it performs completes before data channel event handlers can be
	// called.
	go func() {
		handler(dc)
		close(done)
	}()

	return done
}

// updateMessageSize recomputes maxMessageSize from the remote limit and the
// local send limit. Both are currently fixed at 64 KiB (pion/webrtc#758).
func (r *SCTPTransport) updateMessageSize() {
	r.lock.Lock()
	defer r.lock.Unlock()

	var remoteMaxMessageSize float64 = 65536 // pion/webrtc#758
	var canSendSize float64 = 65536          // pion/webrtc#758

	r.maxMessageSize = r.calcMessageSize(remoteMaxMessageSize, canSendSize)
}

// calcMessageSize returns the smaller of the two limits, treating 0 as
// "unlimited". It returns +Inf when both limits are 0.
func (r *SCTPTransport) calcMessageSize(remoteMaxMessageSize, canSendSize float64) float64 {
	switch {
	case remoteMaxMessageSize == 0 &&
		canSendSize == 0:
		return math.Inf(1)

	case remoteMaxMessageSize == 0:
		return canSendSize

	case canSendSize == 0:
		return remoteMaxMessageSize

	case canSendSize > remoteMaxMessageSize:
		return remoteMaxMessageSize

	default:
		return canSendSize
	}
}

// updateMaxChannels sets the channel limit to sctpMaxChannels.
func (r *SCTPTransport) updateMaxChannels() {
	val := sctpMaxChannels
	r.maxChannels = &val
}

// MaxChannels is the maximum number of RTCDataChannels that can be open simultaneously.
func (r *SCTPTransport) MaxChannels() uint16 {
	// Read-only access, so a shared lock is sufficient.
	r.lock.RLock()
	defer r.lock.RUnlock()

	if r.maxChannels == nil {
		return sctpMaxChannels
	}

	return *r.maxChannels
}

// State returns the current state of the SCTPTransport
func (r *SCTPTransport) State() SCTPTransportState {
	r.lock.RLock()
	defer r.lock.RUnlock()
	return r.state
}

// collectStats reports an SCTPTransportStats entry with ID "sctpTransport".
// Association counters are included only when an association exists.
func (r *SCTPTransport) collectStats(collector *statsReportCollector) {
	collector.Collecting()

	stats := SCTPTransportStats{
		Timestamp: statsTimestampFrom(time.Now()),
		Type:      StatsTypeSCTPTransport,
		ID:        "sctpTransport",
	}

	association := r.association()
	if association != nil {
		stats.BytesSent = association.BytesSent()
		stats.BytesReceived = association.BytesReceived()
		stats.SmoothedRoundTripTime = association.SRTT() * 0.001 // convert milliseconds to seconds
		stats.CongestionWindow = association.CWND()
		stats.ReceiverWindow = association.RWND()
		stats.MTU = association.MTU()
	}

	collector.Collect(stats.ID, stats)
}

// generateAndSetDataChannelID picks the lowest unused stream identifier for
// a new data channel and stores a pointer to it in *idOut.
//
// The DTLS client takes even IDs and any other role takes odd IDs (the
// even/odd split described in RFC 8832). Only IDs below MaxChannels()-1 are
// considered. If none is free, an OperationError wrapping ErrMaxDataChannelID
// is returned and *idOut is left untouched.
func (r *SCTPTransport) generateAndSetDataChannelID(dtlsRole DTLSRole, idOut **uint16) error {
	var id uint16
	if dtlsRole != DTLSRoleClient {
		id++
	}

	// Read before taking r.lock: MaxChannels acquires the same lock.
	maxID := r.MaxChannels()

	r.lock.Lock()
	defer r.lock.Unlock()

	// Collect the IDs in use so each candidate is checked without rescanning
	// r.dataChannels.
	idsInUse := make(map[uint16]struct{}, len(r.dataChannels))
	for _, dc := range r.dataChannels {
		if dcID := dc.ID(); dcID != nil {
			idsInUse[*dcID] = struct{}{}
		}
	}

	// Stepping by two keeps the parity chosen above.
	for ; id < maxID-1; id += 2 {
		if _, taken := idsInUse[id]; taken {
			continue
		}
		*idOut = &id
		return nil
	}

	return &rtcerr.OperationError{Err: ErrMaxDataChannelID}
}

// association returns the current SCTP association. It returns nil when r is
// nil, before Start has set an association, or after Stop has cleared it.
func (r *SCTPTransport) association() *sctp.Association {
	if r == nil {
		return nil
	}
	r.lock.RLock()
	association := r.sctpAssociation
	r.lock.RUnlock()
	return association
}