// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package srtp import ( "crypto/aes" "encoding/binary" ) func aesCmKeyDerivation(label byte, masterKey, masterSalt []byte, indexOverKdr int, outLen int) ([]byte, error) { if indexOverKdr != 0 { // 24-bit "index DIV kdr" must be xored to prf input. return nil, errNonZeroKDRNotSupported } // https://tools.ietf.org/html/rfc3711#appendix-B.3 // The input block for AES-CM is generated by exclusive-oring the master salt with the // concatenation of the encryption key label 0x00 with (index DIV kdr), // - index is 'rollover count' and DIV is 'divided by' nMasterKey := len(masterKey) nMasterSalt := len(masterSalt) prfIn := make([]byte, 16) copy(prfIn[:nMasterSalt], masterSalt) prfIn[7] ^= label // The resulting value is then AES encrypted using the master key to get the cipher key. block, err := aes.NewCipher(masterKey) if err != nil { return nil, err } out := make([]byte, ((outLen+nMasterKey)/nMasterKey)*nMasterKey) var i uint16 for n := 0; n < outLen; n += block.BlockSize() { binary.BigEndian.PutUint16(prfIn[len(prfIn)-2:], i) block.Encrypt(out[n:n+nMasterKey], prfIn) i++ } return out[:outLen], nil } // Generate IV https://tools.ietf.org/html/rfc3711#section-4.1.1 // where the 128-bit integer value IV SHALL be defined by the SSRC, the // SRTP packet index i, and the SRTP session salting key k_s, as below. 
//   - ROC is the 32-bit rollover counter: the number of times the 16-bit
//     RTP sequence number has wrapped from 65535 back to zero.
//   - i = 2^16 * ROC + SEQ is the 48-bit SRTP packet index.
//   - IV = (k_s * 2^16) XOR (SSRC * 2^64) XOR (i * 2^16)
//
// Concretely, sessionSalt is copied into the start of the 16-byte result (at
// most 16 bytes of it are used; shorter salts leave the tail zero). Then SSRC
// is XORed big-endian into bytes 4..7 and the 48-bit index i into bytes 8..13.
func generateCounter(sequenceNumber uint16, rolloverCounter uint32, ssrc uint32, sessionSalt []byte) (counter [16]byte) {
	copy(counter[:], sessionSalt)

	// Combine ROC and SEQ into the 48-bit packet index i.
	packetIndex := uint64(rolloverCounter)<<16 | uint64(sequenceNumber)

	// SSRC * 2^64: most significant byte lands in byte 4.
	for k := 0; k < 4; k++ {
		counter[4+k] ^= byte(ssrc >> uint(24-8*k))
	}

	// i * 2^16: most significant of its 6 bytes lands in byte 8.
	for k := 0; k < 6; k++ {
		counter[8+k] ^= byte(packetIndex >> uint(40-8*k))
	}

	return counter
}