From 166d5e87befae928b8f5f1bfbf44da9cad02e940 Mon Sep 17 00:00:00 2001 From: Richard Ramos Date: Sun, 23 Oct 2022 09:01:24 -0400 Subject: [PATCH] Deduplicate WriteMessage content --- state.go | 133 +++++++++---------------------------------------------- 1 file changed, 22 insertions(+), 111 deletions(-) diff --git a/state.go b/state.go index 0995ec3..feefde8 100644 --- a/state.go +++ b/state.go @@ -362,95 +362,8 @@ func NewHandshakeState(c Config) (*HandshakeState, error) { // peer. It is an error to call this method out of sync with the handshake // pattern. func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) { - if !s.shouldWrite { - return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage") - } - if s.msgIdx > len(s.messagePatterns)-1 { - return nil, nil, nil, errors.New("noise: no handshake messages left") - } - if len(payload) > MaxMsgLen { - return nil, nil, nil, errors.New("noise: message is too long") - } - - var err error - for _, msg := range s.messagePatterns[s.msgIdx] { - switch msg { - case MessagePatternE: - e, err := s.ss.cs.GenerateKeypair(s.rng) - if err != nil { - return nil, nil, nil, err - } - s.e = e - out = append(out, s.e.Public...) - s.ss.MixHash(s.e.Public) - if len(s.psk) > 0 { - s.ss.MixKey(s.e.Public) - } - case MessagePatternS: - if len(s.s.Public) == 0 { - return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil") - } - out, err = s.ss.EncryptAndHash(out, s.s.Public) - if err != nil { - return nil, nil, nil, err - } - case MessagePatternDHEE: - dh, err := s.ss.cs.DH(s.e.Private, s.re) - if err != nil { - return nil, nil, nil, err - } - s.ss.MixKey(dh) - case MessagePatternDHES: - if s.initiator { - dh, err := s.ss.cs.DH(s.e.Private, s.rs) - if err != nil { - return nil, nil, nil, err - } - s.ss.MixKey(dh) - } else { - dh, err := s.ss.cs.DH(s.s.Private, s.re) - if err != nil { - return nil, nil, nil, err - } - s.ss.MixKey(dh) - } - case MessagePatternDHSE: - if s.initiator { - dh, err := s.ss.cs.DH(s.s.Private, s.re) - if err != nil { - return nil, nil, nil, err - } - s.ss.MixKey(dh) - } else { - dh, err := s.ss.cs.DH(s.e.Private, s.rs) - if err != nil { - return nil, nil, nil, err - } - s.ss.MixKey(dh) - } - case MessagePatternDHSS: - dh, err := s.ss.cs.DH(s.s.Private, s.rs) - if err != nil { - return nil, nil, nil, err - } - s.ss.MixKey(dh) - case MessagePatternPSK: - s.ss.MixKeyAndHash(s.psk) - } - } - s.shouldWrite = false - s.msgIdx++ - out, err = s.ss.EncryptAndHash(out, payload) - if err != nil { - return nil, nil, nil, err - } - - if s.msgIdx >= len(s.messagePatterns) { - cs1, cs2 := s.ss.Split() - return out, cs1, cs2, nil - } - - return out, nil, nil, nil + out, _, cs1, cs2, err := s.WriteMessageAndGetPK(out, [][]byte{}, payload) + return out, cs1, cs2, err } // WriteMessageAndGetPK appends a handshake message to out. outPK can possibly contain the @@ -459,15 +372,15 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState // one is used for encryption of messages to the remote peer, the other is used // for decryption of messages from the remote peer. It is an error to call this // method out of sync with the handshake pattern. -func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte, payload []byte) ([]byte, *CipherState, *CipherState, error) { +func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payload []byte) ([]byte, [][]byte, *CipherState, *CipherState, error) { if !s.shouldWrite { - return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage") + return nil, nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage") } if s.msgIdx > len(s.messagePatterns)-1 { - return nil, nil, nil, errors.New("noise: no handshake messages left") + return nil, nil, nil, nil, errors.New("noise: no handshake messages left") } if len(payload) > MaxMsgLen { - return nil, nil, nil, errors.New("noise: message is too long") + return nil, nil, nil, nil, errors.New("noise: message is too long") } var err error @@ -476,45 +389,43 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte, case MessagePatternE: e, err := s.ss.cs.GenerateKeypair(s.rng) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } s.e = e out = append(out, s.e.Public...) - if outNoisePK != nil { - *outNoisePK = append(*outNoisePK, s.e.Public) - } + outPK = append(outPK, s.e.Public) + s.ss.MixHash(s.e.Public) if len(s.psk) > 0 { s.ss.MixKey(s.e.Public) } case MessagePatternS: if len(s.s.Public) == 0 { - return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil") + return nil, nil, nil, nil, errors.New("noise: invalid state, s.Public is nil") } out, err = s.ss.EncryptAndHash(out, s.s.Public) + outPK = append(outPK, out) if err != nil { - return nil, nil, nil, err - } - if outNoisePK != nil { - *outNoisePK = append(*outNoisePK, out) + return nil, nil, nil, nil, err } + case MessagePatternDHEE: dh, err := s.ss.cs.DH(s.e.Private, s.re) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } s.ss.MixKey(dh) case MessagePatternDHES: if s.initiator { dh, err := s.ss.cs.DH(s.e.Private, s.rs) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } s.ss.MixKey(dh) } else { dh, err := s.ss.cs.DH(s.s.Private, s.re) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } s.ss.MixKey(dh) } @@ -522,20 +433,20 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte, if s.initiator { dh, err := s.ss.cs.DH(s.s.Private, s.re) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } s.ss.MixKey(dh) } else { dh, err := s.ss.cs.DH(s.e.Private, s.rs) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } s.ss.MixKey(dh) } case MessagePatternDHSS: dh, err := s.ss.cs.DH(s.s.Private, s.rs) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } s.ss.MixKey(dh) case MessagePatternPSK: @@ -546,15 +457,15 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte, s.msgIdx++ out, err = s.ss.EncryptAndHash(out, payload) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, nil, err } if s.msgIdx >= len(s.messagePatterns) { cs1, cs2 := s.ss.Split() - return out, cs1, cs2, nil + return out, outPK, cs1, cs2, nil } - return out, nil, nil, nil + return out, outPK, nil, nil, nil } // ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.