Compare commits

...

2 Commits

Author SHA1 Message Date
Richard Ramos
4fcac0b407
refactor: deduplicate WriteMessage content 2022-10-23 09:02:31 -04:00
Richard Ramos
d794400c8f
chore: upgrade to go 1.17 and rename module for easier integration 2022-10-23 09:02:10 -04:00
2 changed files with 30 additions and 113 deletions

10
go.mod
View File

@ -1,8 +1,14 @@
module github.com/flynn/noise module github.com/status-im/noise
go 1.16 go 1.17
require ( require (
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c
) )
require (
github.com/kr/pretty v0.2.1 // indirect
github.com/kr/text v0.1.0 // indirect
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 // indirect
)

133
state.go
View File

@ -362,95 +362,8 @@ func NewHandshakeState(c Config) (*HandshakeState, error) {
// peer. It is an error to call this method out of sync with the handshake // peer. It is an error to call this method out of sync with the handshake
// pattern. // pattern.
func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) { func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) {
if !s.shouldWrite { out, _, cs1, cs2, err := s.WriteMessageAndGetPK(out, [][]byte{}, payload)
return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage") return out, cs1, cs2, err
}
if s.msgIdx > len(s.messagePatterns)-1 {
return nil, nil, nil, errors.New("noise: no handshake messages left")
}
if len(payload) > MaxMsgLen {
return nil, nil, nil, errors.New("noise: message is too long")
}
var err error
for _, msg := range s.messagePatterns[s.msgIdx] {
switch msg {
case MessagePatternE:
e, err := s.ss.cs.GenerateKeypair(s.rng)
if err != nil {
return nil, nil, nil, err
}
s.e = e
out = append(out, s.e.Public...)
s.ss.MixHash(s.e.Public)
if len(s.psk) > 0 {
s.ss.MixKey(s.e.Public)
}
case MessagePatternS:
if len(s.s.Public) == 0 {
return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
}
out, err = s.ss.EncryptAndHash(out, s.s.Public)
if err != nil {
return nil, nil, nil, err
}
case MessagePatternDHEE:
dh, err := s.ss.cs.DH(s.e.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
case MessagePatternDHES:
if s.initiator {
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
} else {
dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
}
case MessagePatternDHSE:
if s.initiator {
dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
} else {
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
}
case MessagePatternDHSS:
dh, err := s.ss.cs.DH(s.s.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
case MessagePatternPSK:
s.ss.MixKeyAndHash(s.psk)
}
}
s.shouldWrite = false
s.msgIdx++
out, err = s.ss.EncryptAndHash(out, payload)
if err != nil {
return nil, nil, nil, err
}
if s.msgIdx >= len(s.messagePatterns) {
cs1, cs2 := s.ss.Split()
return out, cs1, cs2, nil
}
return out, nil, nil, nil
} }
// WriteMessageAndGetPK appends a handshake message to out. outPK can possibly contain the // WriteMessageAndGetPK appends a handshake message to out. outPK can possibly contain the
@ -459,15 +372,15 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
// one is used for encryption of messages to the remote peer, the other is used // one is used for encryption of messages to the remote peer, the other is used
// for decryption of messages from the remote peer. It is an error to call this // for decryption of messages from the remote peer. It is an error to call this
// method out of sync with the handshake pattern. // method out of sync with the handshake pattern.
func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte, payload []byte) ([]byte, *CipherState, *CipherState, error) { func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payload []byte) ([]byte, [][]byte, *CipherState, *CipherState, error) {
if !s.shouldWrite { if !s.shouldWrite {
return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage") return nil, nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
} }
if s.msgIdx > len(s.messagePatterns)-1 { if s.msgIdx > len(s.messagePatterns)-1 {
return nil, nil, nil, errors.New("noise: no handshake messages left") return nil, nil, nil, nil, errors.New("noise: no handshake messages left")
} }
if len(payload) > MaxMsgLen { if len(payload) > MaxMsgLen {
return nil, nil, nil, errors.New("noise: message is too long") return nil, nil, nil, nil, errors.New("noise: message is too long")
} }
var err error var err error
@ -476,45 +389,43 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte,
case MessagePatternE: case MessagePatternE:
e, err := s.ss.cs.GenerateKeypair(s.rng) e, err := s.ss.cs.GenerateKeypair(s.rng)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
s.e = e s.e = e
out = append(out, s.e.Public...) out = append(out, s.e.Public...)
if outNoisePK != nil { outPK = append(outPK, s.e.Public)
*outNoisePK = append(*outNoisePK, s.e.Public)
}
s.ss.MixHash(s.e.Public) s.ss.MixHash(s.e.Public)
if len(s.psk) > 0 { if len(s.psk) > 0 {
s.ss.MixKey(s.e.Public) s.ss.MixKey(s.e.Public)
} }
case MessagePatternS: case MessagePatternS:
if len(s.s.Public) == 0 { if len(s.s.Public) == 0 {
return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil") return nil, nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
} }
out, err = s.ss.EncryptAndHash(out, s.s.Public) out, err = s.ss.EncryptAndHash(out, s.s.Public)
outPK = append(outPK, out)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
}
if outNoisePK != nil {
*outNoisePK = append(*outNoisePK, out)
} }
case MessagePatternDHEE: case MessagePatternDHEE:
dh, err := s.ss.cs.DH(s.e.Private, s.re) dh, err := s.ss.cs.DH(s.e.Private, s.re)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
s.ss.MixKey(dh) s.ss.MixKey(dh)
case MessagePatternDHES: case MessagePatternDHES:
if s.initiator { if s.initiator {
dh, err := s.ss.cs.DH(s.e.Private, s.rs) dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
s.ss.MixKey(dh) s.ss.MixKey(dh)
} else { } else {
dh, err := s.ss.cs.DH(s.s.Private, s.re) dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
s.ss.MixKey(dh) s.ss.MixKey(dh)
} }
@ -522,20 +433,20 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte,
if s.initiator { if s.initiator {
dh, err := s.ss.cs.DH(s.s.Private, s.re) dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
s.ss.MixKey(dh) s.ss.MixKey(dh)
} else { } else {
dh, err := s.ss.cs.DH(s.e.Private, s.rs) dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
s.ss.MixKey(dh) s.ss.MixKey(dh)
} }
case MessagePatternDHSS: case MessagePatternDHSS:
dh, err := s.ss.cs.DH(s.s.Private, s.rs) dh, err := s.ss.cs.DH(s.s.Private, s.rs)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
s.ss.MixKey(dh) s.ss.MixKey(dh)
case MessagePatternPSK: case MessagePatternPSK:
@ -546,15 +457,15 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte,
s.msgIdx++ s.msgIdx++
out, err = s.ss.EncryptAndHash(out, payload) out, err = s.ss.EncryptAndHash(out, payload)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
if s.msgIdx >= len(s.messagePatterns) { if s.msgIdx >= len(s.messagePatterns) {
cs1, cs2 := s.ss.Split() cs1, cs2 := s.ss.Split()
return out, cs1, cs2, nil return out, outPK, cs1, cs2, nil
} }
return out, nil, nil, nil return out, outPK, nil, nil, nil
} }
// ErrShortMessage is returned by ReadMessage if a message is not as long as it should be. // ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.