225 lines
5.8 KiB
Go
Raw Normal View History

2024-05-15 19:15:00 -04:00
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
2022-03-10 10:44:48 +01:00
package srtp
import (
"fmt"
2024-05-15 19:15:00 -04:00
"github.com/pion/transport/v2/replaydetector"
2022-03-10 10:44:48 +01:00
)
const (
	// Labels fed to the session-key derivation, one per derived key or salt
	// (see RFC 3711 §4.3.1).
	labelSRTPEncryption         = 0x00
	labelSRTPAuthenticationTag  = 0x01
	labelSRTPSalt               = 0x02
	labelSRTCPEncryption        = 0x03
	labelSRTCPAuthenticationTag = 0x04
	labelSRTCPSalt              = 0x05

	// Largest values of the 16-bit RTP sequence number and the 32-bit
	// rollover counter (ROC).
	maxSequenceNumber = 1<<16 - 1
	maxROC            = 1<<32 - 1

	// seqNumMedian is half of the sequence-number space; nextRolloverCount
	// compares distances against it to decide whether a sequence number
	// wrapped forwards or backwards. seqNumMax is the full space size.
	seqNumMedian = 1 << 15
	seqNumMax    = 1 << 16

	// NOTE(review): presumably the byte width of the SRTCP index field —
	// confirm at its call sites.
	srtcpIndexSize = 4
)
// Encrypt/Decrypt state for a single SRTP SSRC
type srtpSSRCState struct {
ssrc uint32
rolloverHasProcessed bool
2024-05-15 19:15:00 -04:00
index uint64
2022-03-10 10:44:48 +01:00
replayDetector replaydetector.ReplayDetector
}
// srtcpSSRCState is the per-SSRC bookkeeping used when encrypting or
// decrypting SRTCP packets.
type srtcpSSRCState struct {
	// srtcpIndex is the SRTCP index read by Context.Index and written by
	// Context.SetIndex.
	srtcpIndex uint32

	// ssrc identifies the RTCP stream this state belongs to.
	ssrc uint32

	// replayDetector is created from Context.newSRTCPReplayDetector when this
	// state is first allocated.
	replayDetector replaydetector.ReplayDetector
}
// Context represents a SRTP cryptographic context.
// Context can only be used for one-way operations.
// it must either used ONLY for encryption or ONLY for decryption.
2024-05-15 19:15:00 -04:00
// Note that Context does not provide any concurrency protection:
// access to a Context from multiple goroutines requires external
// synchronization.
2022-03-10 10:44:48 +01:00
type Context struct {
cipher srtpCipher
srtpSSRCStates map[uint32]*srtpSSRCState
srtcpSSRCStates map[uint32]*srtcpSSRCState
newSRTCPReplayDetector func() replaydetector.ReplayDetector
newSRTPReplayDetector func() replaydetector.ReplayDetector
}
// CreateContext creates a new SRTP Context.
//
// CreateContext receives variable number of ContextOption-s.
// Passing multiple options which set the same parameter let the last one valid.
// Following example create SRTP Context with replay protection with window size of 256.
//
2024-05-15 19:15:00 -04:00
// decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256))
2022-03-10 10:44:48 +01:00
func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) {
keyLen, err := profile.keyLen()
if err != nil {
return nil, err
}
saltLen, err := profile.saltLen()
if err != nil {
return nil, err
}
if masterKeyLen := len(masterKey); masterKeyLen != keyLen {
return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen)
} else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen {
return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen)
}
c = &Context{
srtpSSRCStates: map[uint32]*srtpSSRCState{},
srtcpSSRCStates: map[uint32]*srtcpSSRCState{},
}
switch profile {
2024-05-15 19:15:00 -04:00
case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm:
c.cipher, err = newSrtpCipherAeadAesGcm(profile, masterKey, masterSalt)
case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80:
c.cipher, err = newSrtpCipherAesCmHmacSha1(profile, masterKey, masterSalt)
2022-03-10 10:44:48 +01:00
default:
return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, profile)
}
if err != nil {
return nil, err
}
for _, o := range append(
[]ContextOption{ // Default options
SRTPNoReplayProtection(),
SRTCPNoReplayProtection(),
},
opts..., // User specified options
) {
if errOpt := o(c); errOpt != nil {
return nil, errOpt
}
}
return c, nil
}
// https://tools.ietf.org/html/rfc3550#appendix-A.1
2024-05-15 19:15:00 -04:00
func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (roc uint32, diff int32, overflow bool) {
2022-03-10 10:44:48 +01:00
seq := int32(sequenceNumber)
localRoc := uint32(s.index >> 16)
localSeq := int32(s.index & (seqNumMax - 1))
guessRoc := localRoc
2024-05-15 19:15:00 -04:00
var difference int32
2022-03-10 10:44:48 +01:00
if s.rolloverHasProcessed {
// When localROC is equal to 0, and entering seq-localSeq > seqNumMedian
// judgment, it will cause guessRoc calculation error
if s.index > seqNumMedian {
if localSeq < seqNumMedian {
if seq-localSeq > seqNumMedian {
guessRoc = localRoc - 1
difference = seq - localSeq - seqNumMax
} else {
guessRoc = localRoc
difference = seq - localSeq
}
} else {
if localSeq-seqNumMedian > seq {
guessRoc = localRoc + 1
difference = seq - localSeq + seqNumMax
} else {
guessRoc = localRoc
difference = seq - localSeq
}
}
} else {
// localRoc is equal to 0
difference = seq - localSeq
}
}
2024-05-15 19:15:00 -04:00
return guessRoc, difference, (guessRoc == 0 && localRoc == maxROC)
}
func (s *srtpSSRCState) updateRolloverCount(sequenceNumber uint16, difference int32) {
if !s.rolloverHasProcessed {
s.index |= uint64(sequenceNumber)
s.rolloverHasProcessed = true
return
}
if difference > 0 {
s.index += uint64(difference)
2022-03-10 10:44:48 +01:00
}
}
// getSRTPSSRCState returns the SRTP state for ssrc. If none exists yet, it
// allocates one (with a fresh replay detector from c.newSRTPReplayDetector)
// and stores it in c.srtpSSRCStates.
func (c *Context) getSRTPSSRCState(ssrc uint32) *srtpSSRCState {
	if existing, found := c.srtpSSRCStates[ssrc]; found {
		return existing
	}

	created := &srtpSSRCState{
		ssrc:           ssrc,
		replayDetector: c.newSRTPReplayDetector(),
	}
	c.srtpSSRCStates[ssrc] = created
	return created
}
// getSRTCPSSRCState returns the SRTCP state for ssrc. If none exists yet, it
// allocates one (with a fresh replay detector from c.newSRTCPReplayDetector)
// and stores it in c.srtcpSSRCStates.
func (c *Context) getSRTCPSSRCState(ssrc uint32) *srtcpSSRCState {
	if existing, found := c.srtcpSSRCStates[ssrc]; found {
		return existing
	}

	created := &srtcpSSRCState{
		ssrc:           ssrc,
		replayDetector: c.newSRTCPReplayDetector(),
	}
	c.srtcpSSRCStates[ssrc] = created
	return created
}
// ROC returns the SRTP rollover counter currently tracked for ssrc.
// The boolean is false (and the counter 0) when no state exists for ssrc.
// Unlike SetROC, ROC never allocates state.
func (c *Context) ROC(ssrc uint32) (uint32, bool) {
	if state, found := c.srtpSSRCStates[ssrc]; found {
		// The rollover counter occupies the bits above the 16-bit sequence number.
		return uint32(state.index >> 16), true
	}
	return 0, false
}
// SetROC sets SRTP rollover counter value of specified SSRC.
func (c *Context) SetROC(ssrc uint32, roc uint32) {
s := c.getSRTPSSRCState(ssrc)
2024-05-15 19:15:00 -04:00
s.index = uint64(roc) << 16
s.rolloverHasProcessed = false
2022-03-10 10:44:48 +01:00
}
// Index returns the SRTCP index currently tracked for ssrc.
// The boolean is false (and the index 0) when no state exists for ssrc.
// Unlike SetIndex, Index never allocates state.
func (c *Context) Index(ssrc uint32) (uint32, bool) {
	if state, found := c.srtcpSSRCStates[ssrc]; found {
		return state.srtcpIndex, true
	}
	return 0, false
}
// SetIndex sets the SRTCP index for ssrc, allocating state if needed.
// Values beyond maxSRTCPIndex wrap around modulo maxSRTCPIndex+1.
func (c *Context) SetIndex(ssrc uint32, index uint32) {
	const indexModulus = maxSRTCPIndex + 1
	c.getSRTCPSSRCState(ssrc).srtcpIndex = index % indexModulus
}