Compare commits

...

2 Commits

Author SHA1 Message Date
Richard Ramos
4fcac0b407
refactor: deduplicate WriteMessage content 2022-10-23 09:02:31 -04:00
Richard Ramos
d794400c8f
chore: upgrade to go 1.17 and rename module for easier integration 2022-10-23 09:02:10 -04:00
2 changed files with 30 additions and 113 deletions

10
go.mod
View File

@ -1,8 +1,14 @@
module github.com/flynn/noise
module github.com/status-im/noise
go 1.16
go 1.17
require (
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c
)
require (
github.com/kr/pretty v0.2.1 // indirect
github.com/kr/text v0.1.0 // indirect
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 // indirect
)

133
state.go
View File

@ -362,95 +362,8 @@ func NewHandshakeState(c Config) (*HandshakeState, error) {
// peer. It is an error to call this method out of sync with the handshake
// pattern.
func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) {
if !s.shouldWrite {
return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
}
if s.msgIdx > len(s.messagePatterns)-1 {
return nil, nil, nil, errors.New("noise: no handshake messages left")
}
if len(payload) > MaxMsgLen {
return nil, nil, nil, errors.New("noise: message is too long")
}
var err error
for _, msg := range s.messagePatterns[s.msgIdx] {
switch msg {
case MessagePatternE:
e, err := s.ss.cs.GenerateKeypair(s.rng)
if err != nil {
return nil, nil, nil, err
}
s.e = e
out = append(out, s.e.Public...)
s.ss.MixHash(s.e.Public)
if len(s.psk) > 0 {
s.ss.MixKey(s.e.Public)
}
case MessagePatternS:
if len(s.s.Public) == 0 {
return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
}
out, err = s.ss.EncryptAndHash(out, s.s.Public)
if err != nil {
return nil, nil, nil, err
}
case MessagePatternDHEE:
dh, err := s.ss.cs.DH(s.e.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
case MessagePatternDHES:
if s.initiator {
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
} else {
dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
}
case MessagePatternDHSE:
if s.initiator {
dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
} else {
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
}
case MessagePatternDHSS:
dh, err := s.ss.cs.DH(s.s.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
case MessagePatternPSK:
s.ss.MixKeyAndHash(s.psk)
}
}
s.shouldWrite = false
s.msgIdx++
out, err = s.ss.EncryptAndHash(out, payload)
if err != nil {
return nil, nil, nil, err
}
if s.msgIdx >= len(s.messagePatterns) {
cs1, cs2 := s.ss.Split()
return out, cs1, cs2, nil
}
return out, nil, nil, nil
out, _, cs1, cs2, err := s.WriteMessageAndGetPK(out, [][]byte{}, payload)
return out, cs1, cs2, err
}
// WriteMessageAndGetPK appends a handshake message to out. outPK can possibly contain the
@ -459,15 +372,15 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
// one is used for encryption of messages to the remote peer, the other is used
// for decryption of messages from the remote peer. It is an error to call this
// method out of sync with the handshake pattern.
func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte, payload []byte) ([]byte, *CipherState, *CipherState, error) {
func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payload []byte) ([]byte, [][]byte, *CipherState, *CipherState, error) {
if !s.shouldWrite {
return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
return nil, nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
}
if s.msgIdx > len(s.messagePatterns)-1 {
return nil, nil, nil, errors.New("noise: no handshake messages left")
return nil, nil, nil, nil, errors.New("noise: no handshake messages left")
}
if len(payload) > MaxMsgLen {
return nil, nil, nil, errors.New("noise: message is too long")
return nil, nil, nil, nil, errors.New("noise: message is too long")
}
var err error
@ -476,45 +389,43 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte,
case MessagePatternE:
e, err := s.ss.cs.GenerateKeypair(s.rng)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
s.e = e
out = append(out, s.e.Public...)
if outNoisePK != nil {
*outNoisePK = append(*outNoisePK, s.e.Public)
}
outPK = append(outPK, s.e.Public)
s.ss.MixHash(s.e.Public)
if len(s.psk) > 0 {
s.ss.MixKey(s.e.Public)
}
case MessagePatternS:
if len(s.s.Public) == 0 {
return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
return nil, nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
}
out, err = s.ss.EncryptAndHash(out, s.s.Public)
outPK = append(outPK, out)
if err != nil {
return nil, nil, nil, err
}
if outNoisePK != nil {
*outNoisePK = append(*outNoisePK, out)
return nil, nil, nil, nil, err
}
case MessagePatternDHEE:
dh, err := s.ss.cs.DH(s.e.Private, s.re)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
s.ss.MixKey(dh)
case MessagePatternDHES:
if s.initiator {
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
s.ss.MixKey(dh)
} else {
dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
s.ss.MixKey(dh)
}
@ -522,20 +433,20 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte,
if s.initiator {
dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
s.ss.MixKey(dh)
} else {
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
s.ss.MixKey(dh)
}
case MessagePatternDHSS:
dh, err := s.ss.cs.DH(s.s.Private, s.rs)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
s.ss.MixKey(dh)
case MessagePatternPSK:
@ -546,15 +457,15 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte,
s.msgIdx++
out, err = s.ss.EncryptAndHash(out, payload)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}
if s.msgIdx >= len(s.messagePatterns) {
cs1, cs2 := s.ss.Split()
return out, cs1, cs2, nil
return out, outPK, cs1, cs2, nil
}
return out, nil, nil, nil
return out, outPK, nil, nil, nil
}
// ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.