From 815c0ed47c1924dce138fcec7bb6bc970578402e Mon Sep 17 00:00:00 2001 From: Richard Ramos Date: Wed, 14 Dec 2022 12:10:23 -0400 Subject: [PATCH] Expose Hash, RS, H and add ad to Encrypt funcs --- state.go | 40 ++++++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/state.go b/state.go index feefde8..ec08782 100644 --- a/state.go +++ b/state.go @@ -10,6 +10,7 @@ import ( "crypto/rand" "errors" "fmt" + "hash" "io" "math" ) @@ -149,12 +150,16 @@ func (s *symmetricState) MixKeyAndHash(data []byte) { s.hasK = true } -func (s *symmetricState) EncryptAndHash(out, plaintext []byte) ([]byte, error) { +// Note that by setting extraAd, it is possible to pass extra additional data that will be concatenated to the ad specified by Noise (can be used to authenticate messageNametag) +func (s *symmetricState) EncryptAndHash(out, plaintext []byte, extraAd ...byte) ([]byte, error) { if !s.hasK { s.MixHash(plaintext) return append(out, plaintext...), nil } - ciphertext, err := s.Encrypt(out, s.h, plaintext) + + ad := append([]byte(nil), s.h...) + ad = append(ad, extraAd...) + ciphertext, err := s.Encrypt(out, ad, plaintext) if err != nil { return nil, err } @@ -162,12 +167,15 @@ func (s *symmetricState) EncryptAndHash(out, plaintext []byte) ([]byte, error) { return ciphertext, nil } -func (s *symmetricState) DecryptAndHash(out, data []byte) ([]byte, error) { +func (s *symmetricState) DecryptAndHash(out, data []byte, extraAd ...byte) ([]byte, error) { if !s.hasK { s.MixHash(data) return append(out, data...), nil } - plaintext, err := s.Decrypt(out, s.h, data) + + ad := append([]byte(nil), s.h...) + ad = append(ad, extraAd...) + plaintext, err := s.Decrypt(out, ad, data) if err != nil { return nil, err } @@ -355,14 +363,22 @@ func NewHandshakeState(c Config) (*HandshakeState, error) { return hs, nil } +func (s *HandshakeState) H() []byte { + return append([]byte(nil), s.ss.h...) +} + +func (s *HandshakeState) RS() []byte { + return append([]byte(nil), s.rs...) +} + // WriteMessage appends a handshake message to out. The message will include the // optional payload if provided. If the handshake is completed by the call, two // CipherStates will be returned, one is used for encryption of messages to the // remote peer, the other is used for decryption of messages from the remote // peer. It is an error to call this method out of sync with the handshake // pattern. -func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) { - out, _, cs1, cs2, err := s.WriteMessageAndGetPK(out, [][]byte{}, payload) +func (s *HandshakeState) WriteMessage(out, payload []byte, extraAd ...byte) ([]byte, *CipherState, *CipherState, error) { + out, _, cs1, cs2, err := s.WriteMessageAndGetPK(out, [][]byte{}, payload, extraAd) return out, cs1, cs2, err } @@ -372,7 +388,7 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState // one is used for encryption of messages to the remote peer, the other is used // for decryption of messages from the remote peer. It is an error to call this // method out of sync with the handshake pattern. -func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payload []byte) ([]byte, [][]byte, *CipherState, *CipherState, error) { +func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payload []byte, extraAd []byte) ([]byte, [][]byte, *CipherState, *CipherState, error) { if !s.shouldWrite { return nil, nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage") } @@ -455,7 +471,7 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payloa } s.shouldWrite = false s.msgIdx++ - out, err = s.ss.EncryptAndHash(out, payload) + out, err = s.ss.EncryptAndHash(out, payload, extraAd...) if err != nil { return nil, nil, nil, nil, err } @@ -468,6 +484,10 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payloa return out, outPK, nil, nil, nil } +func (s *HandshakeState) Hash() hash.Hash { + return s.ss.cs.Hash() +} + // ErrShortMessage is returned by ReadMessage if a message is not as long as it should be. var ErrShortMessage = errors.New("noise: message is too short") @@ -476,7 +496,7 @@ var ErrShortMessage = errors.New("noise: message is too short") // will be returned, one is used for encryption of messages to the remote peer, // the other is used for decryption of messages from the remote peer. It is an // error to call this method out of sync with the handshake pattern. -func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, *CipherState, error) { +func (s *HandshakeState) ReadMessage(out, message []byte, extraAd ...byte) ([]byte, *CipherState, *CipherState, error) { if s.shouldWrite { return nil, nil, nil, errors.New("noise: unexpected call to ReadMessage should be WriteMessage") } @@ -568,7 +588,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, s.ss.MixKeyAndHash(s.psk) } } - out, err = s.ss.DecryptAndHash(out, message) + out, err = s.ss.DecryptAndHash(out, message, extraAd...) if err != nil { s.ss.Rollback() if rsSet {