mirror of
https://github.com/logos-messaging/noise.git
synced 2026-01-04 07:03:08 +00:00
Deduplicate WriteMessage content
This commit is contained in:
parent
da2a9c978d
commit
166d5e87be
133
state.go
133
state.go
@ -362,95 +362,8 @@ func NewHandshakeState(c Config) (*HandshakeState, error) {
|
|||||||
// peer. It is an error to call this method out of sync with the handshake
|
// peer. It is an error to call this method out of sync with the handshake
|
||||||
// pattern.
|
// pattern.
|
||||||
func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) {
|
func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) {
|
||||||
if !s.shouldWrite {
|
out, _, cs1, cs2, err := s.WriteMessageAndGetPK(out, [][]byte{}, payload)
|
||||||
return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
|
return out, cs1, cs2, err
|
||||||
}
|
|
||||||
if s.msgIdx > len(s.messagePatterns)-1 {
|
|
||||||
return nil, nil, nil, errors.New("noise: no handshake messages left")
|
|
||||||
}
|
|
||||||
if len(payload) > MaxMsgLen {
|
|
||||||
return nil, nil, nil, errors.New("noise: message is too long")
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
for _, msg := range s.messagePatterns[s.msgIdx] {
|
|
||||||
switch msg {
|
|
||||||
case MessagePatternE:
|
|
||||||
e, err := s.ss.cs.GenerateKeypair(s.rng)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
s.e = e
|
|
||||||
out = append(out, s.e.Public...)
|
|
||||||
s.ss.MixHash(s.e.Public)
|
|
||||||
if len(s.psk) > 0 {
|
|
||||||
s.ss.MixKey(s.e.Public)
|
|
||||||
}
|
|
||||||
case MessagePatternS:
|
|
||||||
if len(s.s.Public) == 0 {
|
|
||||||
return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
|
|
||||||
}
|
|
||||||
out, err = s.ss.EncryptAndHash(out, s.s.Public)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
case MessagePatternDHEE:
|
|
||||||
dh, err := s.ss.cs.DH(s.e.Private, s.re)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
s.ss.MixKey(dh)
|
|
||||||
case MessagePatternDHES:
|
|
||||||
if s.initiator {
|
|
||||||
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
s.ss.MixKey(dh)
|
|
||||||
} else {
|
|
||||||
dh, err := s.ss.cs.DH(s.s.Private, s.re)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
s.ss.MixKey(dh)
|
|
||||||
}
|
|
||||||
case MessagePatternDHSE:
|
|
||||||
if s.initiator {
|
|
||||||
dh, err := s.ss.cs.DH(s.s.Private, s.re)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
s.ss.MixKey(dh)
|
|
||||||
} else {
|
|
||||||
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
s.ss.MixKey(dh)
|
|
||||||
}
|
|
||||||
case MessagePatternDHSS:
|
|
||||||
dh, err := s.ss.cs.DH(s.s.Private, s.rs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
s.ss.MixKey(dh)
|
|
||||||
case MessagePatternPSK:
|
|
||||||
s.ss.MixKeyAndHash(s.psk)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.shouldWrite = false
|
|
||||||
s.msgIdx++
|
|
||||||
out, err = s.ss.EncryptAndHash(out, payload)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.msgIdx >= len(s.messagePatterns) {
|
|
||||||
cs1, cs2 := s.ss.Split()
|
|
||||||
return out, cs1, cs2, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, nil, nil, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteMessageAndGetPK appends a handshake message to out. outPK can possibly contain the
|
// WriteMessageAndGetPK appends a handshake message to out. outPK can possibly contain the
|
||||||
@ -459,15 +372,15 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
|
|||||||
// one is used for encryption of messages to the remote peer, the other is used
|
// one is used for encryption of messages to the remote peer, the other is used
|
||||||
// for decryption of messages from the remote peer. It is an error to call this
|
// for decryption of messages from the remote peer. It is an error to call this
|
||||||
// method out of sync with the handshake pattern.
|
// method out of sync with the handshake pattern.
|
||||||
func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte, payload []byte) ([]byte, *CipherState, *CipherState, error) {
|
func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payload []byte) ([]byte, [][]byte, *CipherState, *CipherState, error) {
|
||||||
if !s.shouldWrite {
|
if !s.shouldWrite {
|
||||||
return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
|
return nil, nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
|
||||||
}
|
}
|
||||||
if s.msgIdx > len(s.messagePatterns)-1 {
|
if s.msgIdx > len(s.messagePatterns)-1 {
|
||||||
return nil, nil, nil, errors.New("noise: no handshake messages left")
|
return nil, nil, nil, nil, errors.New("noise: no handshake messages left")
|
||||||
}
|
}
|
||||||
if len(payload) > MaxMsgLen {
|
if len(payload) > MaxMsgLen {
|
||||||
return nil, nil, nil, errors.New("noise: message is too long")
|
return nil, nil, nil, nil, errors.New("noise: message is too long")
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
@ -476,45 +389,43 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte,
|
|||||||
case MessagePatternE:
|
case MessagePatternE:
|
||||||
e, err := s.ss.cs.GenerateKeypair(s.rng)
|
e, err := s.ss.cs.GenerateKeypair(s.rng)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
}
|
}
|
||||||
s.e = e
|
s.e = e
|
||||||
out = append(out, s.e.Public...)
|
out = append(out, s.e.Public...)
|
||||||
if outNoisePK != nil {
|
outPK = append(outPK, s.e.Public)
|
||||||
*outNoisePK = append(*outNoisePK, s.e.Public)
|
|
||||||
}
|
|
||||||
s.ss.MixHash(s.e.Public)
|
s.ss.MixHash(s.e.Public)
|
||||||
if len(s.psk) > 0 {
|
if len(s.psk) > 0 {
|
||||||
s.ss.MixKey(s.e.Public)
|
s.ss.MixKey(s.e.Public)
|
||||||
}
|
}
|
||||||
case MessagePatternS:
|
case MessagePatternS:
|
||||||
if len(s.s.Public) == 0 {
|
if len(s.s.Public) == 0 {
|
||||||
return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
|
return nil, nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
|
||||||
}
|
}
|
||||||
out, err = s.ss.EncryptAndHash(out, s.s.Public)
|
out, err = s.ss.EncryptAndHash(out, s.s.Public)
|
||||||
|
outPK = append(outPK, out)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
}
|
|
||||||
if outNoisePK != nil {
|
|
||||||
*outNoisePK = append(*outNoisePK, out)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case MessagePatternDHEE:
|
case MessagePatternDHEE:
|
||||||
dh, err := s.ss.cs.DH(s.e.Private, s.re)
|
dh, err := s.ss.cs.DH(s.e.Private, s.re)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
}
|
}
|
||||||
s.ss.MixKey(dh)
|
s.ss.MixKey(dh)
|
||||||
case MessagePatternDHES:
|
case MessagePatternDHES:
|
||||||
if s.initiator {
|
if s.initiator {
|
||||||
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
|
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
}
|
}
|
||||||
s.ss.MixKey(dh)
|
s.ss.MixKey(dh)
|
||||||
} else {
|
} else {
|
||||||
dh, err := s.ss.cs.DH(s.s.Private, s.re)
|
dh, err := s.ss.cs.DH(s.s.Private, s.re)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
}
|
}
|
||||||
s.ss.MixKey(dh)
|
s.ss.MixKey(dh)
|
||||||
}
|
}
|
||||||
@ -522,20 +433,20 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte,
|
|||||||
if s.initiator {
|
if s.initiator {
|
||||||
dh, err := s.ss.cs.DH(s.s.Private, s.re)
|
dh, err := s.ss.cs.DH(s.s.Private, s.re)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
}
|
}
|
||||||
s.ss.MixKey(dh)
|
s.ss.MixKey(dh)
|
||||||
} else {
|
} else {
|
||||||
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
|
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
}
|
}
|
||||||
s.ss.MixKey(dh)
|
s.ss.MixKey(dh)
|
||||||
}
|
}
|
||||||
case MessagePatternDHSS:
|
case MessagePatternDHSS:
|
||||||
dh, err := s.ss.cs.DH(s.s.Private, s.rs)
|
dh, err := s.ss.cs.DH(s.s.Private, s.rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
}
|
}
|
||||||
s.ss.MixKey(dh)
|
s.ss.MixKey(dh)
|
||||||
case MessagePatternPSK:
|
case MessagePatternPSK:
|
||||||
@ -546,15 +457,15 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte,
|
|||||||
s.msgIdx++
|
s.msgIdx++
|
||||||
out, err = s.ss.EncryptAndHash(out, payload)
|
out, err = s.ss.EncryptAndHash(out, payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.msgIdx >= len(s.messagePatterns) {
|
if s.msgIdx >= len(s.messagePatterns) {
|
||||||
cs1, cs2 := s.ss.Split()
|
cs1, cs2 := s.ss.Split()
|
||||||
return out, cs1, cs2, nil
|
return out, outPK, cs1, cs2, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return out, nil, nil, nil
|
return out, outPK, nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.
|
// ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user