refactor: deduplicate WriteMessage content

This commit is contained in:
Richard Ramos 2022-10-23 09:01:24 -04:00
parent d794400c8f
commit 4fcac0b407
No known key found for this signature in database
GPG Key ID: BD36D48BC9FFC88C

133
state.go
View File

@ -362,95 +362,8 @@ func NewHandshakeState(c Config) (*HandshakeState, error) {
// peer. It is an error to call this method out of sync with the handshake
// pattern.
func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) {
if !s.shouldWrite {
return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
}
if s.msgIdx > len(s.messagePatterns)-1 {
return nil, nil, nil, errors.New("noise: no handshake messages left")
}
if len(payload) > MaxMsgLen {
return nil, nil, nil, errors.New("noise: message is too long")
}
var err error
for _, msg := range s.messagePatterns[s.msgIdx] {
switch msg {
case MessagePatternE:
e, err := s.ss.cs.GenerateKeypair(s.rng)
if err != nil {
return nil, nil, nil, err
}
s.e = e
out = append(out, s.e.Public...)
s.ss.MixHash(s.e.Public)
if len(s.psk) > 0 {
s.ss.MixKey(s.e.Public)
}
case MessagePatternS:
if len(s.s.Public) == 0 {
return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
}
out, err = s.ss.EncryptAndHash(out, s.s.Public)
if err != nil {
return nil, nil, nil, err
}
case MessagePatternDHEE:
dh, err := s.ss.cs.DH(s.e.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
case MessagePatternDHES:
if s.initiator {
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
} else {
dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
}
case MessagePatternDHSE:
if s.initiator {
dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
} else {
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
}
case MessagePatternDHSS:
dh, err := s.ss.cs.DH(s.s.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
case MessagePatternPSK:
s.ss.MixKeyAndHash(s.psk)
}
}
s.shouldWrite = false
s.msgIdx++
out, err = s.ss.EncryptAndHash(out, payload)
if err != nil {
return nil, nil, nil, err
}
if s.msgIdx >= len(s.messagePatterns) {
cs1, cs2 := s.ss.Split()
return out, cs1, cs2, nil
}
return out, nil, nil, nil
out, _, cs1, cs2, err := s.WriteMessageAndGetPK(out, [][]byte{}, payload)
return out, cs1, cs2, err
}
// WriteMessageAndGetPK appends a handshake message to out. outPK can possibly contain the
@ -459,15 +372,15 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
// one is used for encryption of messages to the remote peer, the other is used
// for decryption of messages from the remote peer. It is an error to call this
// method out of sync with the handshake pattern.
func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte, payload []byte) ([]byte, *CipherState, *CipherState, error) {
func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payload []byte) ([]byte, [][]byte, *CipherState, *CipherState, error) {
if !s.shouldWrite {
return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
return nil, nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
}
if s.msgIdx > len(s.messagePatterns)-1 {
return nil, nil, nil, errors.New("noise: no handshake messages left")
return nil, nil, nil, nil, errors.New("noise: no handshake messages left")
}
if len(payload) > MaxMsgLen {
return nil, nil, nil, errors.New("noise: message is too long")
return nil, nil, nil, nil, errors.New("noise: message is too long")
}
var err error
@ -476,45 +389,43 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte,
case MessagePatternE:
e, err := s.ss.cs.GenerateKeypair(s.rng)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
s.e = e
out = append(out, s.e.Public...)
if outNoisePK != nil {
*outNoisePK = append(*outNoisePK, s.e.Public)
}
outPK = append(outPK, s.e.Public)
s.ss.MixHash(s.e.Public)
if len(s.psk) > 0 {
s.ss.MixKey(s.e.Public)
}
case MessagePatternS:
if len(s.s.Public) == 0 {
return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
return nil, nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
}
out, err = s.ss.EncryptAndHash(out, s.s.Public)
outPK = append(outPK, out)
if err != nil {
return nil, nil, nil, err
}
if outNoisePK != nil {
*outNoisePK = append(*outNoisePK, out)
return nil, nil, nil, nil, err
}
case MessagePatternDHEE:
dh, err := s.ss.cs.DH(s.e.Private, s.re)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
s.ss.MixKey(dh)
case MessagePatternDHES:
if s.initiator {
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
s.ss.MixKey(dh)
} else {
dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
s.ss.MixKey(dh)
}
@ -522,20 +433,20 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte,
if s.initiator {
dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
s.ss.MixKey(dh)
} else {
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
s.ss.MixKey(dh)
}
case MessagePatternDHSS:
dh, err := s.ss.cs.DH(s.s.Private, s.rs)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
s.ss.MixKey(dh)
case MessagePatternPSK:
@ -546,15 +457,15 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte,
s.msgIdx++
out, err = s.ss.EncryptAndHash(out, payload)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
if s.msgIdx >= len(s.messagePatterns) {
cs1, cs2 := s.ss.Split()
return out, cs1, cs2, nil
return out, outPK, cs1, cs2, nil
}
return out, nil, nil, nil
return out, outPK, nil, nil, nil
}
// ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.