status-go/vendor/github.com/pion/dtls/v2/flight4bhandler.go

145 lines
4.6 KiB
Go
Raw Normal View History

2024-05-15 19:15:00 -04:00
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
2022-03-10 10:44:48 +01:00
package dtls
import (
"bytes"
"context"
"github.com/pion/dtls/v2/pkg/crypto/prf"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/alert"
"github.com/pion/dtls/v2/pkg/protocol/extension"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)
2024-05-15 19:15:00 -04:00
func flight4bParse(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
_, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
2022-03-10 10:44:48 +01:00
handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
)
if !ok {
// No valid message received. Keep reading
return 0, nil, nil
}
var finished *handshake.MessageFinished
if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
}
plainText := cache.pullAndMerge(
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
)
expectedVerifyData, err := prf.VerifyDataClient(state.masterSecret, plainText, state.cipherSuite.HashFunc())
if err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
}
// Other party may re-transmit the last flight. Keep state to be flight4b.
return flight4b, nil, nil
}
2024-05-15 19:15:00 -04:00
func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
2022-03-10 10:44:48 +01:00
var pkts []*packet
extensions := []extension.Extension{&extension.RenegotiationInfo{
RenegotiatedConnection: 0,
}}
if (cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret {
extensions = append(extensions, &extension.UseExtendedMasterSecret{
Supported: true,
})
}
2024-06-05 16:10:03 -04:00
if state.getSRTPProtectionProfile() != 0 {
2022-03-10 10:44:48 +01:00
extensions = append(extensions, &extension.UseSRTP{
2024-06-05 16:10:03 -04:00
ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()},
2022-03-10 10:44:48 +01:00
})
}
selectedProto, err := extension.ALPNProtocolSelection(cfg.supportedProtocols, state.peerSupportedProtocols)
if err != nil {
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.NoApplicationProtocol}, err
}
if selectedProto != "" {
extensions = append(extensions, &extension.ALPN{
ProtocolNameList: []string{selectedProto},
})
state.NegotiatedProtocol = selectedProto
}
cipherSuiteID := uint16(state.cipherSuite.ID())
serverHello := &handshake.Handshake{
Message: &handshake.MessageServerHello{
Version: protocol.Version1_2,
Random: state.localRandom,
SessionID: state.SessionID,
CipherSuiteID: &cipherSuiteID,
CompressionMethod: defaultCompressionMethods()[0],
Extensions: extensions,
},
}
serverHello.Header.MessageSequence = uint16(state.handshakeSendSequence)
if len(state.localVerifyData) == 0 {
plainText := cache.pullAndMerge(
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
)
raw, err := serverHello.Marshal()
if err != nil {
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
plainText = append(plainText, raw...)
state.localVerifyData, err = prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
if err != nil {
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
}
pkts = append(pkts,
&packet{
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
Content: serverHello,
},
},
&packet{
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
Content: &protocol.ChangeCipherSpec{},
},
},
&packet{
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
Epoch: 1,
},
Content: &handshake.Handshake{
Message: &handshake.MessageFinished{
VerifyData: state.localVerifyData,
},
},
},
shouldEncrypt: true,
resetLocalSequenceNumber: true,
},
)
return pkts, nil, nil
}