mirror of https://github.com/waku-org/noise.git
Expose Hash, RS, H and add ad to Encrypt funcs
This commit is contained in:
parent
b14b0d0806
commit
815c0ed47c
40
state.go
40
state.go
|
@ -10,6 +10,7 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"hash"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
)
|
)
|
||||||
|
@ -149,12 +150,16 @@ func (s *symmetricState) MixKeyAndHash(data []byte) {
|
||||||
s.hasK = true
|
s.hasK = true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *symmetricState) EncryptAndHash(out, plaintext []byte) ([]byte, error) {
|
// Note that by setting extraAd, it is possible to pass extra additional data that will be concatenated to the ad specified by Noise (can be used to authenticate messageNametag)
|
||||||
|
func (s *symmetricState) EncryptAndHash(out, plaintext []byte, extraAd ...byte) ([]byte, error) {
|
||||||
if !s.hasK {
|
if !s.hasK {
|
||||||
s.MixHash(plaintext)
|
s.MixHash(plaintext)
|
||||||
return append(out, plaintext...), nil
|
return append(out, plaintext...), nil
|
||||||
}
|
}
|
||||||
ciphertext, err := s.Encrypt(out, s.h, plaintext)
|
|
||||||
|
ad := append([]byte(nil), s.h...)
|
||||||
|
ad = append(ad, extraAd...)
|
||||||
|
ciphertext, err := s.Encrypt(out, ad, plaintext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -162,12 +167,15 @@ func (s *symmetricState) EncryptAndHash(out, plaintext []byte) ([]byte, error) {
|
||||||
return ciphertext, nil
|
return ciphertext, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *symmetricState) DecryptAndHash(out, data []byte) ([]byte, error) {
|
func (s *symmetricState) DecryptAndHash(out, data []byte, extraAd ...byte) ([]byte, error) {
|
||||||
if !s.hasK {
|
if !s.hasK {
|
||||||
s.MixHash(data)
|
s.MixHash(data)
|
||||||
return append(out, data...), nil
|
return append(out, data...), nil
|
||||||
}
|
}
|
||||||
plaintext, err := s.Decrypt(out, s.h, data)
|
|
||||||
|
ad := append([]byte(nil), s.h...)
|
||||||
|
ad = append(ad, extraAd...)
|
||||||
|
plaintext, err := s.Decrypt(out, ad, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -355,14 +363,22 @@ func NewHandshakeState(c Config) (*HandshakeState, error) {
|
||||||
return hs, nil
|
return hs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *HandshakeState) H() []byte {
|
||||||
|
return append([]byte(nil), s.ss.h...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HandshakeState) RS() []byte {
|
||||||
|
return append([]byte(nil), s.rs...)
|
||||||
|
}
|
||||||
|
|
||||||
// WriteMessage appends a handshake message to out. The message will include the
|
// WriteMessage appends a handshake message to out. The message will include the
|
||||||
// optional payload if provided. If the handshake is completed by the call, two
|
// optional payload if provided. If the handshake is completed by the call, two
|
||||||
// CipherStates will be returned, one is used for encryption of messages to the
|
// CipherStates will be returned, one is used for encryption of messages to the
|
||||||
// remote peer, the other is used for decryption of messages from the remote
|
// remote peer, the other is used for decryption of messages from the remote
|
||||||
// peer. It is an error to call this method out of sync with the handshake
|
// peer. It is an error to call this method out of sync with the handshake
|
||||||
// pattern.
|
// pattern.
|
||||||
func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) {
|
func (s *HandshakeState) WriteMessage(out, payload []byte, extraAd ...byte) ([]byte, *CipherState, *CipherState, error) {
|
||||||
out, _, cs1, cs2, err := s.WriteMessageAndGetPK(out, [][]byte{}, payload)
|
out, _, cs1, cs2, err := s.WriteMessageAndGetPK(out, [][]byte{}, payload, extraAd)
|
||||||
return out, cs1, cs2, err
|
return out, cs1, cs2, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -372,7 +388,7 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
|
||||||
// one is used for encryption of messages to the remote peer, the other is used
|
// one is used for encryption of messages to the remote peer, the other is used
|
||||||
// for decryption of messages from the remote peer. It is an error to call this
|
// for decryption of messages from the remote peer. It is an error to call this
|
||||||
// method out of sync with the handshake pattern.
|
// method out of sync with the handshake pattern.
|
||||||
func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payload []byte) ([]byte, [][]byte, *CipherState, *CipherState, error) {
|
func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payload []byte, extraAd []byte) ([]byte, [][]byte, *CipherState, *CipherState, error) {
|
||||||
if !s.shouldWrite {
|
if !s.shouldWrite {
|
||||||
return nil, nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
|
return nil, nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
|
||||||
}
|
}
|
||||||
|
@ -455,7 +471,7 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payloa
|
||||||
}
|
}
|
||||||
s.shouldWrite = false
|
s.shouldWrite = false
|
||||||
s.msgIdx++
|
s.msgIdx++
|
||||||
out, err = s.ss.EncryptAndHash(out, payload)
|
out, err = s.ss.EncryptAndHash(out, payload, extraAd...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -468,6 +484,10 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payloa
|
||||||
return out, outPK, nil, nil, nil
|
return out, outPK, nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *HandshakeState) Hash() hash.Hash {
|
||||||
|
return s.ss.cs.Hash()
|
||||||
|
}
|
||||||
|
|
||||||
// ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.
|
// ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.
|
||||||
var ErrShortMessage = errors.New("noise: message is too short")
|
var ErrShortMessage = errors.New("noise: message is too short")
|
||||||
|
|
||||||
|
@ -476,7 +496,7 @@ var ErrShortMessage = errors.New("noise: message is too short")
|
||||||
// will be returned, one is used for encryption of messages to the remote peer,
|
// will be returned, one is used for encryption of messages to the remote peer,
|
||||||
// the other is used for decryption of messages from the remote peer. It is an
|
// the other is used for decryption of messages from the remote peer. It is an
|
||||||
// error to call this method out of sync with the handshake pattern.
|
// error to call this method out of sync with the handshake pattern.
|
||||||
func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, *CipherState, error) {
|
func (s *HandshakeState) ReadMessage(out, message []byte, extraAd ...byte) ([]byte, *CipherState, *CipherState, error) {
|
||||||
if s.shouldWrite {
|
if s.shouldWrite {
|
||||||
return nil, nil, nil, errors.New("noise: unexpected call to ReadMessage should be WriteMessage")
|
return nil, nil, nil, errors.New("noise: unexpected call to ReadMessage should be WriteMessage")
|
||||||
}
|
}
|
||||||
|
@ -568,7 +588,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
|
||||||
s.ss.MixKeyAndHash(s.psk)
|
s.ss.MixKeyAndHash(s.psk)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
out, err = s.ss.DecryptAndHash(out, message)
|
out, err = s.ss.DecryptAndHash(out, message, extraAd...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.ss.Rollback()
|
s.ss.Rollback()
|
||||||
if rsSet {
|
if rsSet {
|
||||||
|
|
Loading…
Reference in New Issue