Expose Hash, RS, H and add ad to Encrypt funcs

This commit is contained in:
Richard Ramos 2022-12-14 12:10:23 -04:00
parent b14b0d0806
commit 815c0ed47c
No known key found for this signature in database
GPG Key ID: BD36D48BC9FFC88C
1 changed files with 30 additions and 10 deletions

View File

@ -10,6 +10,7 @@ import (
"crypto/rand" "crypto/rand"
"errors" "errors"
"fmt" "fmt"
"hash"
"io" "io"
"math" "math"
) )
@ -149,12 +150,16 @@ func (s *symmetricState) MixKeyAndHash(data []byte) {
s.hasK = true s.hasK = true
} }
func (s *symmetricState) EncryptAndHash(out, plaintext []byte) ([]byte, error) { // Note that by setting extraAd, it is possible to pass extra additional data that will be concatenated to the ad specified by Noise (can be used to authenticate messageNametag)
func (s *symmetricState) EncryptAndHash(out, plaintext []byte, extraAd ...byte) ([]byte, error) {
if !s.hasK { if !s.hasK {
s.MixHash(plaintext) s.MixHash(plaintext)
return append(out, plaintext...), nil return append(out, plaintext...), nil
} }
ciphertext, err := s.Encrypt(out, s.h, plaintext)
ad := append([]byte(nil), s.h...)
ad = append(ad, extraAd...)
ciphertext, err := s.Encrypt(out, ad, plaintext)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -162,12 +167,15 @@ func (s *symmetricState) EncryptAndHash(out, plaintext []byte) ([]byte, error) {
return ciphertext, nil return ciphertext, nil
} }
func (s *symmetricState) DecryptAndHash(out, data []byte) ([]byte, error) { func (s *symmetricState) DecryptAndHash(out, data []byte, extraAd ...byte) ([]byte, error) {
if !s.hasK { if !s.hasK {
s.MixHash(data) s.MixHash(data)
return append(out, data...), nil return append(out, data...), nil
} }
plaintext, err := s.Decrypt(out, s.h, data)
ad := append([]byte(nil), s.h...)
ad = append(ad, extraAd...)
plaintext, err := s.Decrypt(out, ad, data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -355,14 +363,22 @@ func NewHandshakeState(c Config) (*HandshakeState, error) {
return hs, nil return hs, nil
} }
func (s *HandshakeState) H() []byte {
return append([]byte(nil), s.ss.h...)
}
func (s *HandshakeState) RS() []byte {
return append([]byte(nil), s.rs...)
}
// WriteMessage appends a handshake message to out. The message will include the // WriteMessage appends a handshake message to out. The message will include the
// optional payload if provided. If the handshake is completed by the call, two // optional payload if provided. If the handshake is completed by the call, two
// CipherStates will be returned, one is used for encryption of messages to the // CipherStates will be returned, one is used for encryption of messages to the
// remote peer, the other is used for decryption of messages from the remote // remote peer, the other is used for decryption of messages from the remote
// peer. It is an error to call this method out of sync with the handshake // peer. It is an error to call this method out of sync with the handshake
// pattern. // pattern.
func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) { func (s *HandshakeState) WriteMessage(out, payload []byte, extraAd ...byte) ([]byte, *CipherState, *CipherState, error) {
out, _, cs1, cs2, err := s.WriteMessageAndGetPK(out, [][]byte{}, payload) out, _, cs1, cs2, err := s.WriteMessageAndGetPK(out, [][]byte{}, payload, extraAd)
return out, cs1, cs2, err return out, cs1, cs2, err
} }
@ -372,7 +388,7 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
// one is used for encryption of messages to the remote peer, the other is used // one is used for encryption of messages to the remote peer, the other is used
// for decryption of messages from the remote peer. It is an error to call this // for decryption of messages from the remote peer. It is an error to call this
// method out of sync with the handshake pattern. // method out of sync with the handshake pattern.
func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payload []byte) ([]byte, [][]byte, *CipherState, *CipherState, error) { func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payload []byte, extraAd []byte) ([]byte, [][]byte, *CipherState, *CipherState, error) {
if !s.shouldWrite { if !s.shouldWrite {
return nil, nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage") return nil, nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
} }
@ -455,7 +471,7 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payloa
} }
s.shouldWrite = false s.shouldWrite = false
s.msgIdx++ s.msgIdx++
out, err = s.ss.EncryptAndHash(out, payload) out, err = s.ss.EncryptAndHash(out, payload, extraAd...)
if err != nil { if err != nil {
return nil, nil, nil, nil, err return nil, nil, nil, nil, err
} }
@ -468,6 +484,10 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payloa
return out, outPK, nil, nil, nil return out, outPK, nil, nil, nil
} }
func (s *HandshakeState) Hash() hash.Hash {
return s.ss.cs.Hash()
}
// ErrShortMessage is returned by ReadMessage if a message is not as long as it should be. // ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.
var ErrShortMessage = errors.New("noise: message is too short") var ErrShortMessage = errors.New("noise: message is too short")
@ -476,7 +496,7 @@ var ErrShortMessage = errors.New("noise: message is too short")
// will be returned, one is used for encryption of messages to the remote peer, // will be returned, one is used for encryption of messages to the remote peer,
// the other is used for decryption of messages from the remote peer. It is an // the other is used for decryption of messages from the remote peer. It is an
// error to call this method out of sync with the handshake pattern. // error to call this method out of sync with the handshake pattern.
func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, *CipherState, error) { func (s *HandshakeState) ReadMessage(out, message []byte, extraAd ...byte) ([]byte, *CipherState, *CipherState, error) {
if s.shouldWrite { if s.shouldWrite {
return nil, nil, nil, errors.New("noise: unexpected call to ReadMessage should be WriteMessage") return nil, nil, nil, errors.New("noise: unexpected call to ReadMessage should be WriteMessage")
} }
@ -568,7 +588,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
s.ss.MixKeyAndHash(s.psk) s.ss.MixKeyAndHash(s.psk)
} }
} }
out, err = s.ss.DecryptAndHash(out, message) out, err = s.ss.DecryptAndHash(out, message, extraAd...)
if err != nil { if err != nil {
s.ss.Rollback() s.ss.Rollback()
if rsSet { if rsSet {