229 lines
6.5 KiB
Go
Raw Normal View History

2024-05-15 19:15:00 -04:00
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
2022-03-10 10:44:48 +01:00
package dtls
import (
"bytes"
"encoding/gob"
"sync/atomic"
"github.com/pion/dtls/v2/pkg/crypto/elliptic"
"github.com/pion/dtls/v2/pkg/crypto/prf"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
2024-05-15 19:15:00 -04:00
"github.com/pion/transport/v2/replaydetector"
2022-03-10 10:44:48 +01:00
)
// State holds the dtls connection state and implements both encoding.BinaryMarshaler and encoding.BinaryUnmarshaler
type State struct {
localEpoch, remoteEpoch atomic.Value
localSequenceNumber []uint64 // uint48
localRandom, remoteRandom handshake.Random
masterSecret []byte
cipherSuite CipherSuite // nil if a cipherSuite hasn't been chosen
2024-06-05 16:10:03 -04:00
srtpProtectionProfile atomic.Value // Negotiated SRTPProtectionProfile
2022-03-10 10:44:48 +01:00
PeerCertificates [][]byte
IdentityHint []byte
SessionID []byte
isClient bool
preMasterSecret []byte
extendedMasterSecret bool
namedCurve elliptic.Curve
localKeypair *elliptic.Keypair
cookie []byte
handshakeSendSequence int
handshakeRecvSequence int
serverName string
remoteRequestedCertificate bool // Did we get a CertificateRequest
localCertificatesVerify []byte // cache CertificateVerify
localVerifyData []byte // cached VerifyData
localKeySignature []byte // cached keySignature
peerCertificatesVerified bool
replayDetector []replaydetector.ReplayDetector
peerSupportedProtocols []string
NegotiatedProtocol string
}
// serializedState is the flat snapshot of State that MarshalBinary and
// UnmarshalBinary gob-encode and that clone round-trips through.
// gob matches struct fields by name, so renaming a field here means data
// marshaled under the old name no longer populates it.
type serializedState struct {
LocalEpoch uint16 // value of getLocalEpoch when serialized
RemoteEpoch uint16 // value of getRemoteEpoch when serialized
LocalRandom [handshake.RandomLength]byte // handshake.Random in MarshalFixed form
RemoteRandom [handshake.RandomLength]byte // handshake.Random in MarshalFixed form
CipherSuiteID uint16 // restored through cipherSuiteForID on decode
MasterSecret []byte
SequenceNumber uint64 // local sequence number for LocalEpoch only
SRTPProtectionProfile uint16 // negotiated SRTPProtectionProfile as its numeric value
PeerCertificates [][]byte
IdentityHint []byte
SessionID []byte
IsClient bool
}
// clone builds a new State by round-tripping the receiver through
// serialize/deserialize. Only the fields present in serializedState are
// carried over. Byte slices are shared with the receiver rather than copied,
// and the cipher suite is re-created from its ID without being initialized.
func (s *State) clone() *State {
	copied := new(State)
	copied.deserialize(*s.serialize())
	return copied
}
func (s *State) serialize() *serializedState {
// Marshal random values
localRnd := s.localRandom.MarshalFixed()
remoteRnd := s.remoteRandom.MarshalFixed()
2024-05-15 19:15:00 -04:00
epoch := s.getLocalEpoch()
2022-03-10 10:44:48 +01:00
return &serializedState{
2024-05-15 19:15:00 -04:00
LocalEpoch: s.getLocalEpoch(),
RemoteEpoch: s.getRemoteEpoch(),
2022-03-10 10:44:48 +01:00
CipherSuiteID: uint16(s.cipherSuite.ID()),
MasterSecret: s.masterSecret,
SequenceNumber: atomic.LoadUint64(&s.localSequenceNumber[epoch]),
LocalRandom: localRnd,
RemoteRandom: remoteRnd,
2024-06-05 16:10:03 -04:00
SRTPProtectionProfile: uint16(s.getSRTPProtectionProfile()),
2022-03-10 10:44:48 +01:00
PeerCertificates: s.PeerCertificates,
IdentityHint: s.IdentityHint,
SessionID: s.SessionID,
IsClient: s.isClient,
}
}
func (s *State) deserialize(serialized serializedState) {
// Set epoch values
epoch := serialized.LocalEpoch
s.localEpoch.Store(serialized.LocalEpoch)
s.remoteEpoch.Store(serialized.RemoteEpoch)
for len(s.localSequenceNumber) <= int(epoch) {
s.localSequenceNumber = append(s.localSequenceNumber, uint64(0))
}
// Set random values
localRandom := &handshake.Random{}
localRandom.UnmarshalFixed(serialized.LocalRandom)
s.localRandom = *localRandom
remoteRandom := &handshake.Random{}
remoteRandom.UnmarshalFixed(serialized.RemoteRandom)
s.remoteRandom = *remoteRandom
s.isClient = serialized.IsClient
// Set master secret
s.masterSecret = serialized.MasterSecret
// Set cipher suite
s.cipherSuite = cipherSuiteForID(CipherSuiteID(serialized.CipherSuiteID), nil)
atomic.StoreUint64(&s.localSequenceNumber[epoch], serialized.SequenceNumber)
2024-06-05 16:10:03 -04:00
s.setSRTPProtectionProfile(SRTPProtectionProfile(serialized.SRTPProtectionProfile))
2022-03-10 10:44:48 +01:00
// Set remote certificate
s.PeerCertificates = serialized.PeerCertificates
s.IdentityHint = serialized.IdentityHint
s.SessionID = serialized.SessionID
}
// initCipherSuite initializes the chosen cipher suite from the master secret
// and the two handshake randoms, unless it is already initialized.
//
// The randoms are passed client-first: a client passes (local, remote), a
// server passes (remote, local), and the final argument is isClient.
//
// A nil cipherSuite means no suite has been chosen (e.g. a State captured
// before negotiation and restored by UnmarshalBinary), so there is nothing to
// initialize. In that case it returns nil instead of panicking on a method
// call through a nil interface.
// NOTE(review): an unrecognized CipherSuiteID in UnmarshalBinary input
// presumably also yields a nil suite via cipherSuiteForID and is accepted
// here; add stricter validation there if needed.
func (s *State) initCipherSuite() error {
	if s.cipherSuite == nil || s.cipherSuite.IsInitialized() {
		return nil
	}

	localRandom := s.localRandom.MarshalFixed()
	remoteRandom := s.remoteRandom.MarshalFixed()

	if s.isClient {
		return s.cipherSuite.Init(s.masterSecret, localRandom[:], remoteRandom[:], true)
	}
	return s.cipherSuite.Init(s.masterSecret, remoteRandom[:], localRandom[:], false)
}
// MarshalBinary implements encoding.BinaryMarshaler. It gob-encodes the
// serializable subset of the State (the fields of serializedState).
func (s *State) MarshalBinary() ([]byte, error) {
	var out bytes.Buffer
	if err := gob.NewEncoder(&out).Encode(*s.serialize()); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// UnmarshalBinary is a binary.BinaryUnmarshaler.UnmarshalBinary implementation
func (s *State) UnmarshalBinary(data []byte) error {
enc := gob.NewDecoder(bytes.NewBuffer(data))
var serialized serializedState
if err := enc.Decode(&serialized); err != nil {
return err
}
s.deserialize(serialized)
2024-05-15 19:15:00 -04:00
return s.initCipherSuite()
2022-03-10 10:44:48 +01:00
}
// ExportKeyingMaterial returns length bytes of exported key material in a new
// slice as defined in RFC 5705.
// This allows protocols to use DTLS for key establishment, but
// then use some of the keying material for their own purposes
func (s *State) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
2024-05-15 19:15:00 -04:00
if s.getLocalEpoch() == 0 {
2022-03-10 10:44:48 +01:00
return nil, errHandshakeInProgress
} else if len(context) != 0 {
return nil, errContextUnsupported
} else if _, ok := invalidKeyingLabels()[label]; ok {
return nil, errReservedExportKeyingMaterial
}
localRandom := s.localRandom.MarshalFixed()
remoteRandom := s.remoteRandom.MarshalFixed()
seed := []byte(label)
if s.isClient {
seed = append(append(seed, localRandom[:]...), remoteRandom[:]...)
} else {
seed = append(append(seed, remoteRandom[:]...), localRandom[:]...)
}
return prf.PHash(s.masterSecret, seed, length, s.cipherSuite.HashFunc())
}
2024-05-15 19:15:00 -04:00
// getRemoteEpoch returns the stored remote epoch, or 0 if none has been
// stored yet (or the stored value is not a uint16).
func (s *State) getRemoteEpoch() uint16 {
	epoch, _ := s.remoteEpoch.Load().(uint16)
	return epoch
}
// getLocalEpoch returns the stored local epoch, or 0 if none has been
// stored yet (or the stored value is not a uint16).
func (s *State) getLocalEpoch() uint16 {
	epoch, _ := s.localEpoch.Load().(uint16)
	return epoch
}
2024-06-05 16:10:03 -04:00
// setSRTPProtectionProfile records the negotiated SRTP protection profile in
// an atomic.Value, which supports concurrent Load/Store.
func (s *State) setSRTPProtectionProfile(p SRTPProtectionProfile) {
	slot := &s.srtpProtectionProfile
	slot.Store(p)
}
// getSRTPProtectionProfile returns the negotiated SRTP protection profile,
// or the zero profile if none has been recorded yet.
func (s *State) getSRTPProtectionProfile() SRTPProtectionProfile {
	profile, _ := s.srtpProtectionProfile.Load().(SRTPProtectionProfile)
	return profile
}