2024-05-15 19:15:00 -04:00

162 lines
3.4 KiB
Go

// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package srtp
import (
"errors"
"io"
"net"
"sync"
"time"
"github.com/pion/logging"
"github.com/pion/transport/v2/packetio"
)
type streamSession interface {
Close() error
write([]byte) (int, error)
decrypt([]byte) error
}
type session struct {
localContextMutex sync.Mutex
localContext, remoteContext *Context
localOptions, remoteOptions []ContextOption
newStream chan readStream
acceptStreamTimeout time.Time
started chan interface{}
closed chan interface{}
readStreamsClosed bool
readStreams map[uint32]readStream
readStreamsLock sync.Mutex
log logging.LeveledLogger
bufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser
nextConn net.Conn
}
// Config is used to configure a session.
// You can provide either a KeyingMaterialExporter to export keys
// or directly pass the keys themselves.
// After a Config is passed to a session it must not be modified.
type Config struct {
Keys SessionKeys
Profile ProtectionProfile
BufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser
LoggerFactory logging.LoggerFactory
AcceptStreamTimeout time.Time
// List of local/remote context options.
// ReplayProtection is enabled on remote context by default.
// Default replay protection window size is 64.
LocalOptions, RemoteOptions []ContextOption
}
// SessionKeys bundles the keys required to setup an SRTP session
type SessionKeys struct {
LocalMasterKey []byte
LocalMasterSalt []byte
RemoteMasterKey []byte
RemoteMasterSalt []byte
}
func (s *session) getOrCreateReadStream(ssrc uint32, child streamSession, proto func() readStream) (readStream, bool) {
s.readStreamsLock.Lock()
defer s.readStreamsLock.Unlock()
if s.readStreamsClosed {
return nil, false
}
r, ok := s.readStreams[ssrc]
if ok {
return r, false
}
// Create the readStream.
r = proto()
if err := r.init(child, ssrc); err != nil {
return nil, false
}
s.readStreams[ssrc] = r
return r, true
}
func (s *session) removeReadStream(ssrc uint32) {
s.readStreamsLock.Lock()
defer s.readStreamsLock.Unlock()
if s.readStreamsClosed {
return
}
delete(s.readStreams, ssrc)
}
func (s *session) close() error {
if s.nextConn == nil {
return nil
} else if err := s.nextConn.Close(); err != nil {
return err
}
<-s.closed
return nil
}
func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, child streamSession) error {
var err error
s.localContext, err = CreateContext(localMasterKey, localMasterSalt, profile, s.localOptions...)
if err != nil {
return err
}
s.remoteContext, err = CreateContext(remoteMasterKey, remoteMasterSalt, profile, s.remoteOptions...)
if err != nil {
return err
}
if err = s.nextConn.SetReadDeadline(s.acceptStreamTimeout); err != nil {
return err
}
go func() {
defer func() {
close(s.newStream)
s.readStreamsLock.Lock()
s.readStreamsClosed = true
s.readStreamsLock.Unlock()
close(s.closed)
}()
b := make([]byte, 8192)
for {
var i int
i, err = s.nextConn.Read(b)
if err != nil {
if !errors.Is(err, io.EOF) {
s.log.Error(err.Error())
}
return
}
if err = child.decrypt(b[:i]); err != nil {
s.log.Info(err.Error())
}
}
}()
close(s.started)
return nil
}