// status-go/vendor/github.com/pion/srtp/v2/session_srtcp.go
// (original listing: 181 lines, 4.3 KiB, Go)

package srtp
import (
"net"
"time"
"github.com/pion/logging"
"github.com/pion/rtcp"
)
const (
	// defaultSessionSRTCPReplayProtectionWindow is the window size handed to
	// SRTCPReplayProtection as the first remote ContextOption in
	// NewSessionSRTCP; options from Config.RemoteOptions are appended after it.
	defaultSessionSRTCPReplayProtectionWindow = 64
)
// SessionSRTCP implements io.ReadWriteCloser and provides a bi-directional
// SRTCP session.
//
// SRTCP itself does not define a session abstraction like this, but it is
// common for the local and remote sides of an application to each have their
// own keying material. SessionSRTCP packages that pattern so callers do not
// have to re-implement it.
type SessionSRTCP struct {
	// session holds the shared state: underlying conn, context options,
	// read streams and lifecycle channels.
	session

	// writeStream is the single write stream handed out by OpenWriteStream.
	writeStream *WriteStreamSRTCP
}
// NewSessionSRTCP creates an SRTCP session that uses conn as its underlying
// transport.
//
// It returns errNoConfig when config is nil and errNoConn when conn is nil.
// When config.LoggerFactory is nil, a default logger factory is used. The
// remote context options always begin with SRTCP replay protection using
// defaultSessionSRTCPReplayProtectionWindow, followed by config.RemoteOptions.
// The local options are a copy of config.LocalOptions. Any error from starting
// the embedded session is returned unchanged.
func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //nolint:dupl
	switch {
	case config == nil:
		return nil, errNoConfig
	case conn == nil:
		return nil, errNoConn
	}

	loggerFactory := config.LoggerFactory
	if loggerFactory == nil {
		loggerFactory = logging.NewDefaultLoggerFactory()
	}

	// Copy the caller's option slices so the session never aliases them.
	localOpts := make([]ContextOption, 0, len(config.LocalOptions))
	localOpts = append(localOpts, config.LocalOptions...)

	remoteOpts := make([]ContextOption, 0, 1+len(config.RemoteOptions))
	// Default option first; caller-supplied options follow it.
	remoteOpts = append(remoteOpts, SRTCPReplayProtection(defaultSessionSRTCPReplayProtectionWindow))
	remoteOpts = append(remoteOpts, config.RemoteOptions...)

	sess := &SessionSRTCP{
		session: session{
			nextConn:      conn,
			localOptions:  localOpts,
			remoteOptions: remoteOpts,
			readStreams:   map[uint32]readStream{},
			newStream:     make(chan readStream),
			started:       make(chan interface{}),
			closed:        make(chan interface{}),
			bufferFactory: config.BufferFactory,
			log:           loggerFactory.NewLogger("srtp"),
		},
	}
	sess.writeStream = &WriteStreamSRTCP{sess}

	keys := config.Keys
	if err := sess.session.start(
		keys.LocalMasterKey, keys.LocalMasterSalt,
		keys.RemoteMasterKey, keys.RemoteMasterSalt,
		config.Profile,
		sess,
	); err != nil {
		return nil, err
	}
	return sess, nil
}
// OpenWriteStream returns the session's single, shared write stream.
// The returned error is always nil.
func (s *SessionSRTCP) OpenWriteStream() (*WriteStreamSRTCP, error) {
	ws := s.writeStream
	return ws, nil
}
// OpenReadStream returns the read stream for ssrc, obtained through
// getOrCreateReadStream. Use it when you want a specific SSRC but do not want
// to wait for AcceptStream.
//
// If getOrCreateReadStream does not yield a *ReadStreamSRTCP (including when
// it yields no stream at all), errFailedTypeAssertion is returned.
func (s *SessionSRTCP) OpenReadStream(ssrc uint32) (*ReadStreamSRTCP, error) {
	stream, _ := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP)
	rs, ok := stream.(*ReadStreamSRTCP)
	if !ok {
		return nil, errFailedTypeAssertion
	}
	return rs, nil
}
// AcceptStream waits for the next stream announced on the session's newStream
// channel and returns it along with its SSRC. It returns
// errStreamAlreadyClosed once that channel has been closed, and
// errFailedTypeAssertion if the announced stream is not a *ReadStreamSRTCP.
func (s *SessionSRTCP) AcceptStream() (*ReadStreamSRTCP, uint32, error) {
	next, open := <-s.newStream
	if !open {
		return nil, 0, errStreamAlreadyClosed
	}
	if rs, isSRTCP := next.(*ReadStreamSRTCP); isSRTCP {
		return rs, next.GetSSRC(), nil
	}
	return nil, 0, errFailedTypeAssertion
}
// Close ends the session by delegating to the embedded session's close and
// returning its error.
func (s *SessionSRTCP) Close() error {
	err := s.session.close()
	return err
}
// Private

// write encrypts buf with the local context and writes the resulting SRTCP
// packet to the underlying connection.
//
// It first receives from the started channel. If an actual value arrives
// (instead of the channel being closed), that is treated as misuse and
// errStartedChannelUsedIncorrectly is returned. The local-context mutex is held
// only around EncryptRTCP, not around the network write.
func (s *SessionSRTCP) write(buf []byte) (int, error) {
	if _, ok := <-s.session.started; ok {
		return 0, errStartedChannelUsedIncorrectly
	}

	encrypted, err := func() ([]byte, error) {
		s.session.localContextMutex.Lock()
		defer s.session.localContextMutex.Unlock()
		return s.session.localContext.EncryptRTCP(nil, buf, nil)
	}()
	if err != nil {
		return 0, err
	}
	return s.session.nextConn.Write(encrypted)
}
// setWriteDeadline forwards t to SetWriteDeadline on the session's underlying
// connection.
func (s *SessionSRTCP) setWriteDeadline(t time.Time) error {
	conn := s.session.nextConn
	return conn.SetWriteDeadline(t)
}
// destinationSSRC returns every distinct destination SSRC referenced by pkts,
// i.e. the union of p.DestinationSSRC() over all packets.
//
// Each SSRC appears exactly once, in the order it is first encountered. The
// previous implementation collected the set in a map and ranged over it,
// which produced a random order on every call. Callers fan out over this
// slice (and, in decrypt, announce new streams in this order), so a stable
// order keeps that behavior reproducible. The result is non-nil even when
// empty.
func destinationSSRC(pkts []rtcp.Packet) []uint32 {
	seen := make(map[uint32]struct{})
	out := make([]uint32, 0)
	for _, p := range pkts {
		for _, ssrc := range p.DestinationSSRC() {
			if _, dup := seen[ssrc]; dup {
				continue
			}
			seen[ssrc] = struct{}{}
			out = append(out, ssrc)
		}
	}
	return out
}
// decrypt decrypts one SRTCP packet with the remote context, passing buf as
// both destination and source. It parses the result as RTCP and writes the
// decrypted bytes to the read stream of every destination SSRC found in it.
//
// A stream that getOrCreateReadStream reports as newly created is sent on
// newStream (for AcceptStream) before any data is written to it. If no stream
// is returned, decrypt stops and returns nil; this code treats that as the
// session having been closed. The first decryption, parse, type-assertion or
// stream-write error aborts processing and is returned.
func (s *SessionSRTCP) decrypt(buf []byte) error {
	plain, err := s.session.remoteContext.DecryptRTCP(buf, buf, nil)
	if err != nil {
		return err
	}

	packets, err := rtcp.Unmarshal(plain)
	if err != nil {
		return err
	}

	for _, ssrc := range destinationSSRC(packets) {
		stream, created := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP)
		switch {
		case stream == nil:
			// No stream available; treated as the session having been closed.
			return nil
		case created:
			// newStream is unbuffered, so this blocks until a receiver
			// (e.g. AcceptStream) takes the stream.
			s.session.newStream <- stream
		}

		srtcpStream, ok := stream.(*ReadStreamSRTCP)
		if !ok {
			return errFailedTypeAssertion
		}
		if _, err := srtcpStream.write(plain); err != nil {
			return err
		}
	}
	return nil
}