diff --git a/state.go b/state.go index 985eea5..0995ec3 100644 --- a/state.go +++ b/state.go @@ -453,6 +453,110 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState return out, nil, nil, nil } +// WriteMessageAndGetPK appends a handshake message to out. outPK can possibly contain the +// party public keys. The message will include the optional payload if provided. +// If the handshake is completed by the call, two CipherStates will be returned, +// one is used for encryption of messages to the remote peer, the other is used +// for decryption of messages from the remote peer. It is an error to call this +// method out of sync with the handshake pattern. +func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte, payload []byte) ([]byte, *CipherState, *CipherState, error) { + if !s.shouldWrite { + return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage") + } + if s.msgIdx > len(s.messagePatterns)-1 { + return nil, nil, nil, errors.New("noise: no handshake messages left") + } + if len(payload) > MaxMsgLen { + return nil, nil, nil, errors.New("noise: message is too long") + } + + var err error + for _, msg := range s.messagePatterns[s.msgIdx] { + switch msg { + case MessagePatternE: + e, err := s.ss.cs.GenerateKeypair(s.rng) + if err != nil { + return nil, nil, nil, err + } + s.e = e + out = append(out, s.e.Public...) + if outNoisePK != nil { + *outNoisePK = append(*outNoisePK, s.e.Public) + } + s.ss.MixHash(s.e.Public) + if len(s.psk) > 0 { + s.ss.MixKey(s.e.Public) + } + case MessagePatternS: + if len(s.s.Public) == 0 { + return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil") + } + out, err = s.ss.EncryptAndHash(out, s.s.Public) + if err != nil { + return nil, nil, nil, err + } + if outNoisePK != nil { + *outNoisePK = append(*outNoisePK, out) + } + case MessagePatternDHEE: + dh, err := s.ss.cs.DH(s.e.Private, s.re) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) + case MessagePatternDHES: + if s.initiator { + dh, err := s.ss.cs.DH(s.e.Private, s.rs) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) + } else { + dh, err := s.ss.cs.DH(s.s.Private, s.re) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) + } + case MessagePatternDHSE: + if s.initiator { + dh, err := s.ss.cs.DH(s.s.Private, s.re) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) + } else { + dh, err := s.ss.cs.DH(s.e.Private, s.rs) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) + } + case MessagePatternDHSS: + dh, err := s.ss.cs.DH(s.s.Private, s.rs) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) + case MessagePatternPSK: + s.ss.MixKeyAndHash(s.psk) + } + } + s.shouldWrite = false + s.msgIdx++ + out, err = s.ss.EncryptAndHash(out, payload) + if err != nil { + return nil, nil, nil, err + } + + if s.msgIdx >= len(s.messagePatterns) { + cs1, cs2 := s.ss.Split() + return out, cs1, cs2, nil + } + + return out, nil, nil, nil +} + // ErrShortMessage is returned by ReadMessage if a message is not as long as it should be. var ErrShortMessage = errors.New("noise: message is too short")