216 lines
5.3 KiB
Go
Raw Normal View History

2022-03-10 10:44:48 +01:00
package srtp
import (
"fmt"
"github.com/pion/transport/replaydetector"
)
// Label values 0x00 through 0x05: one each for the encryption,
// authentication-tag and salt keys of SRTP and of SRTCP.
// NOTE(review): presumably the key-derivation labels of RFC 3711 section
// 4.3.1; they are not referenced in this part of the file, so confirm.
const (
	labelSRTPEncryption         = 0x00
	labelSRTPAuthenticationTag  = 0x01
	labelSRTPSalt               = 0x02
	labelSRTCPEncryption        = 0x03
	labelSRTCPAuthenticationTag = 0x04
	labelSRTCPSalt              = 0x05
)

// Sequence-number and index sizing constants.
const (
	// seqNumMax is the number of distinct 16-bit sequence numbers (65536).
	seqNumMax = 1 << 16
	// seqNumMedian is half of the sequence-number space (32768).
	seqNumMedian = seqNumMax / 2
	// maxSequenceNumber is the largest 16-bit sequence number (65535).
	maxSequenceNumber = seqNumMax - 1
	// srtcpIndexSize is 4.
	// NOTE(review): presumably the byte length of the SRTCP index field; confirm.
	srtcpIndexSize = 4
)
// srtpSSRCState holds the encrypt/decrypt state for a single SRTP SSRC.
type srtpSSRCState struct {
// ssrc is the SSRC this state was created for.
ssrc uint32
// index packs the rollover counter (the bits above the low 16) together
// with the sequence number (the low 16 bits).
index uint64
// rolloverHasProcessed is false until the first update function returned
// by nextRolloverCount has run; that first run seeds the low bits of index
// with the packet's sequence number.
rolloverHasProcessed bool
// replayDetector is the per-SSRC detector obtained from the owning
// Context's newSRTPReplayDetector factory.
replayDetector replaydetector.ReplayDetector
}
// srtcpSSRCState holds the encrypt/decrypt state for a single SRTCP SSRC.
type srtcpSSRCState struct {
// srtcpIndex is the SRTCP index of this SSRC; SetIndex stores it reduced
// modulo maxSRTCPIndex+1 and Index reports it.
srtcpIndex uint32
// ssrc is the SSRC this state was created for.
ssrc uint32
// replayDetector is the per-SSRC detector obtained from the owning
// Context's newSRTCPReplayDetector factory.
replayDetector replaydetector.ReplayDetector
}
// Context represents an SRTP cryptographic context.
// A Context can only be used for one-way operations:
// it must be used either ONLY for encryption or ONLY for decryption.
type Context struct {
// cipher is the packet cipher CreateContext selects from the protection
// profile.
cipher srtpCipher
// srtpSSRCStates maps an SSRC to its SRTP state; missing entries are
// created on demand by getSRTPSSRCState.
srtpSSRCStates map[uint32]*srtpSSRCState
// srtcpSSRCStates maps an SSRC to its SRTCP state; missing entries are
// created on demand by getSRTCPSSRCState.
srtcpSSRCStates map[uint32]*srtcpSSRCState
// newSRTCPReplayDetector builds the replay detector for each newly
// created SRTCP state.
// NOTE(review): presumably installed by a ContextOption; confirm.
newSRTCPReplayDetector func() replaydetector.ReplayDetector
// newSRTPReplayDetector builds the replay detector for each newly
// created SRTP state.
// NOTE(review): presumably installed by a ContextOption; confirm.
newSRTPReplayDetector func() replaydetector.ReplayDetector
}
// CreateContext creates a new SRTP Context.
//
// masterKey and masterSalt must have exactly the lengths the profile reports
// through keyLen and saltLen. A mismatch yields an error wrapping
// errShortSrtpMasterKey or errShortSrtpMasterSalt that states the expected and
// the actual length. The key material itself is never placed in the message.
// A profile without a matching cipher yields an error wrapping
// errNoSuchSRTPProfile.
//
// CreateContext accepts a variable number of ContextOptions. The defaults
// SRTPNoReplayProtection and SRTCPNoReplayProtection are applied first and the
// caller's options afterwards, in order, so when several options set the same
// parameter the last one wins. The following example creates an SRTP Context
// with replay protection and a window size of 256:
//
//	decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256))
func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) {
	keyLen, err := profile.keyLen()
	if err != nil {
		return nil, err
	}

	saltLen, err := profile.saltLen()
	if err != nil {
		return nil, err
	}

	// Report lengths only: expected first, actual second.
	if masterKeyLen := len(masterKey); masterKeyLen != keyLen {
		return nil, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, keyLen, masterKeyLen)
	}
	if masterSaltLen := len(masterSalt); masterSaltLen != saltLen {
		return nil, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen)
	}

	c = &Context{
		srtpSSRCStates:  map[uint32]*srtpSSRCState{},
		srtcpSSRCStates: map[uint32]*srtcpSSRCState{},
	}

	switch profile {
	case ProtectionProfileAeadAes128Gcm:
		c.cipher, err = newSrtpCipherAeadAesGcm(masterKey, masterSalt)
	case ProtectionProfileAes128CmHmacSha1_80:
		c.cipher, err = newSrtpCipherAesCmHmacSha1(masterKey, masterSalt)
	default:
		return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, profile)
	}
	if err != nil {
		return nil, err
	}

	// Defaults go first so that any user option overrides them.
	defaults := []ContextOption{
		SRTPNoReplayProtection(),
		SRTCPNoReplayProtection(),
	}
	for _, o := range append(defaults, opts...) {
		if errOpt := o(c); errOpt != nil {
			return nil, errOpt
		}
	}

	return c, nil
}
// nextRolloverCount guesses the rollover counter (ROC) that belongs to
// sequenceNumber, based on the index currently stored for this SSRC.
//
// It returns the guessed ROC and a commit function. Nothing is written to s
// until the commit function is called. The first commit ever run seeds the
// low bits of the index with sequenceNumber; later commits only move the
// index forward, so a packet at or behind the stored index leaves it alone.
//
// See https://tools.ietf.org/html/rfc3550#appendix-A.1
func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (uint32, func()) {
	const seqMask = seqNumMax - 1

	incoming := int32(sequenceNumber)
	lastSeq := int32(s.index & seqMask)
	roc := uint32(s.index >> 16)

	// advance is the signed distance from the stored index to the incoming
	// packet, corrected for wrap-around where one is detected.
	var advance int32
	if s.rolloverHasProcessed {
		advance = incoming - lastSeq

		// While the index is still in the first half of ROC 0 there is no
		// previous cycle to fall back to, so the wrap-around checks are
		// skipped and the plain difference is used.
		if s.index > seqNumMedian {
			switch {
			case lastSeq < seqNumMedian && advance > seqNumMedian:
				// Far "ahead" of a low stored value: a late packet from
				// the previous cycle.
				roc--
				advance -= seqNumMax
			case lastSeq >= seqNumMedian && lastSeq-seqNumMedian > incoming:
				// Far "behind" a high stored value: the sequence number
				// wrapped into the next cycle.
				roc++
				advance += seqNumMax
			}
		}
	}

	commit := func() {
		if !s.rolloverHasProcessed {
			s.index |= uint64(sequenceNumber)
			s.rolloverHasProcessed = true
			return
		}
		if advance > 0 {
			s.index += uint64(advance)
		}
	}
	return roc, commit
}
// getSRTPSSRCState returns the SRTP state registered for ssrc. If none exists
// yet, a new state with a replay detector from c.newSRTPReplayDetector is
// created, stored in c.srtpSSRCStates and returned.
func (c *Context) getSRTPSSRCState(ssrc uint32) *srtpSSRCState {
	if existing, found := c.srtpSSRCStates[ssrc]; found {
		return existing
	}

	created := &srtpSSRCState{
		ssrc:           ssrc,
		replayDetector: c.newSRTPReplayDetector(),
	}
	c.srtpSSRCStates[ssrc] = created
	return created
}
// getSRTCPSSRCState returns the SRTCP state registered for ssrc. If none
// exists yet, a new state with a replay detector from
// c.newSRTCPReplayDetector is created, stored in c.srtcpSSRCStates and
// returned.
func (c *Context) getSRTCPSSRCState(ssrc uint32) *srtcpSSRCState {
	if existing, found := c.srtcpSSRCStates[ssrc]; found {
		return existing
	}

	created := &srtcpSSRCState{
		ssrc:           ssrc,
		replayDetector: c.newSRTCPReplayDetector(),
	}
	c.srtcpSSRCStates[ssrc] = created
	return created
}
// ROC reports the SRTP rollover counter stored for ssrc, taken from the bits
// of the state's index above the low 16. The boolean is false, and the
// counter 0, when no state exists for ssrc; the lookup never creates one.
func (c *Context) ROC(ssrc uint32) (uint32, bool) {
	if state, found := c.srtpSSRCStates[ssrc]; found {
		return uint32(state.index >> 16), true
	}
	return 0, false
}
// SetROC sets the SRTP rollover counter of the given SSRC.
//
// The state for ssrc is created through getSRTPSSRCState when it does not
// exist yet. Only the rollover-counter part of the index is replaced; the low
// 16 bits (the sequence number) keep their current value.
func (c *Context) SetROC(ssrc uint32, roc uint32) {
	s := c.getSRTPSSRCState(ssrc)
	// Widen roc to 64 bits before shifting: shifting the uint32 first would
	// silently drop the upper 16 bits of the rollover counter.
	s.index = uint64(roc)<<16 | (s.index & (seqNumMax - 1))
}
// Index reports the SRTCP index stored for ssrc. The boolean is false, and
// the index 0, when no state exists for ssrc; the lookup never creates one.
func (c *Context) Index(ssrc uint32) (uint32, bool) {
	if state, found := c.srtcpSSRCStates[ssrc]; found {
		return state.srtcpIndex, true
	}
	return 0, false
}
// SetIndex stores index as the SRTCP index of the given SSRC, reduced modulo
// maxSRTCPIndex+1. The state for ssrc is created through getSRTCPSSRCState
// when it does not exist yet.
func (c *Context) SetIndex(ssrc uint32, index uint32) {
	state := c.getSRTCPSSRCState(ssrc)
	wrapped := index % (maxSRTCPIndex + 1)
	state.srtcpIndex = wrapped
}