Compare commits

..

5 Commits

Author SHA1 Message Date
Richard Ramos
b2cae9b389
Add generateKeyPairFromPrivateKey 2022-12-14 12:16:37 -04:00
Richard Ramos
815c0ed47c
Expose Hash, RS, H and add ad to Encrypt funcs 2022-12-14 12:16:29 -04:00
Richard Ramos
b14b0d0806
Change project org 2022-12-14 12:16:20 -04:00
Richard Ramos
166d5e87be
Deduplicate WriteMessage content 2022-12-14 12:16:10 -04:00
Richard Ramos
da2a9c978d
Upgrade to go 1.17 and rename module for easier integration 2022-12-14 12:16:01 -04:00
4 changed files with 68 additions and 123 deletions

View File

@ -28,6 +28,9 @@ type DHFunc interface {
// entropy. // entropy.
GenerateKeypair(random io.Reader) (DHKey, error) GenerateKeypair(random io.Reader) (DHKey, error)
// GenerateKeypairFromPrivateKEy generates a keypair from a private key
GenerateKeyPairFromPrivateKey(privkey []byte) (DHKey, error)
// DH performs a Diffie-Hellman calculation between the provided private and // DH performs a Diffie-Hellman calculation between the provided private and
// public keys and returns the result. // public keys and returns the result.
DH(privkey, pubkey []byte) ([]byte, error) DH(privkey, pubkey []byte) ([]byte, error)
@ -104,7 +107,7 @@ var DH25519 DHFunc = dh25519{}
type dh25519 struct{} type dh25519 struct{}
func (dh25519) GenerateKeypair(rng io.Reader) (DHKey, error) { func (d dh25519) GenerateKeypair(rng io.Reader) (DHKey, error) {
privkey := make([]byte, 32) privkey := make([]byte, 32)
if rng == nil { if rng == nil {
rng = rand.Reader rng = rand.Reader
@ -112,6 +115,11 @@ func (dh25519) GenerateKeypair(rng io.Reader) (DHKey, error) {
if _, err := io.ReadFull(rng, privkey); err != nil { if _, err := io.ReadFull(rng, privkey); err != nil {
return DHKey{}, err return DHKey{}, err
} }
return d.GenerateKeyPairFromPrivateKey(privkey)
}
func (d dh25519) GenerateKeyPairFromPrivateKey(privkey []byte) (DHKey, error) {
pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
if err != nil { if err != nil {
return DHKey{}, err return DHKey{}, err

10
go.mod
View File

@ -1,8 +1,14 @@
module github.com/flynn/noise module github.com/waku-org/noise
go 1.16 go 1.17
require ( require (
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c
) )
require (
github.com/kr/pretty v0.2.1 // indirect
github.com/kr/text v0.1.0 // indirect
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 // indirect
)

169
state.go
View File

@ -10,6 +10,7 @@ import (
"crypto/rand" "crypto/rand"
"errors" "errors"
"fmt" "fmt"
"hash"
"io" "io"
"math" "math"
) )
@ -149,12 +150,16 @@ func (s *symmetricState) MixKeyAndHash(data []byte) {
s.hasK = true s.hasK = true
} }
func (s *symmetricState) EncryptAndHash(out, plaintext []byte) ([]byte, error) { // Note that by setting extraAd, it is possible to pass extra additional data that will be concatenated to the ad specified by Noise (can be used to authenticate messageNametag)
func (s *symmetricState) EncryptAndHash(out, plaintext []byte, extraAd ...byte) ([]byte, error) {
if !s.hasK { if !s.hasK {
s.MixHash(plaintext) s.MixHash(plaintext)
return append(out, plaintext...), nil return append(out, plaintext...), nil
} }
ciphertext, err := s.Encrypt(out, s.h, plaintext)
ad := append([]byte(nil), s.h...)
ad = append(ad, extraAd...)
ciphertext, err := s.Encrypt(out, ad, plaintext)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -162,12 +167,15 @@ func (s *symmetricState) EncryptAndHash(out, plaintext []byte) ([]byte, error) {
return ciphertext, nil return ciphertext, nil
} }
func (s *symmetricState) DecryptAndHash(out, data []byte) ([]byte, error) { func (s *symmetricState) DecryptAndHash(out, data []byte, extraAd ...byte) ([]byte, error) {
if !s.hasK { if !s.hasK {
s.MixHash(data) s.MixHash(data)
return append(out, data...), nil return append(out, data...), nil
} }
plaintext, err := s.Decrypt(out, s.h, data)
ad := append([]byte(nil), s.h...)
ad = append(ad, extraAd...)
plaintext, err := s.Decrypt(out, ad, data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -355,102 +363,23 @@ func NewHandshakeState(c Config) (*HandshakeState, error) {
return hs, nil return hs, nil
} }
func (s *HandshakeState) H() []byte {
return append([]byte(nil), s.ss.h...)
}
func (s *HandshakeState) RS() []byte {
return append([]byte(nil), s.rs...)
}
// WriteMessage appends a handshake message to out. The message will include the // WriteMessage appends a handshake message to out. The message will include the
// optional payload if provided. If the handshake is completed by the call, two // optional payload if provided. If the handshake is completed by the call, two
// CipherStates will be returned, one is used for encryption of messages to the // CipherStates will be returned, one is used for encryption of messages to the
// remote peer, the other is used for decryption of messages from the remote // remote peer, the other is used for decryption of messages from the remote
// peer. It is an error to call this method out of sync with the handshake // peer. It is an error to call this method out of sync with the handshake
// pattern. // pattern.
func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState, *CipherState, error) { func (s *HandshakeState) WriteMessage(out, payload []byte, extraAd ...byte) ([]byte, *CipherState, *CipherState, error) {
if !s.shouldWrite { out, _, cs1, cs2, err := s.WriteMessageAndGetPK(out, [][]byte{}, payload, extraAd)
return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage") return out, cs1, cs2, err
}
if s.msgIdx > len(s.messagePatterns)-1 {
return nil, nil, nil, errors.New("noise: no handshake messages left")
}
if len(payload) > MaxMsgLen {
return nil, nil, nil, errors.New("noise: message is too long")
}
var err error
for _, msg := range s.messagePatterns[s.msgIdx] {
switch msg {
case MessagePatternE:
e, err := s.ss.cs.GenerateKeypair(s.rng)
if err != nil {
return nil, nil, nil, err
}
s.e = e
out = append(out, s.e.Public...)
s.ss.MixHash(s.e.Public)
if len(s.psk) > 0 {
s.ss.MixKey(s.e.Public)
}
case MessagePatternS:
if len(s.s.Public) == 0 {
return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
}
out, err = s.ss.EncryptAndHash(out, s.s.Public)
if err != nil {
return nil, nil, nil, err
}
case MessagePatternDHEE:
dh, err := s.ss.cs.DH(s.e.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
case MessagePatternDHES:
if s.initiator {
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
} else {
dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
}
case MessagePatternDHSE:
if s.initiator {
dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
} else {
dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
}
case MessagePatternDHSS:
dh, err := s.ss.cs.DH(s.s.Private, s.rs)
if err != nil {
return nil, nil, nil, err
}
s.ss.MixKey(dh)
case MessagePatternPSK:
s.ss.MixKeyAndHash(s.psk)
}
}
s.shouldWrite = false
s.msgIdx++
out, err = s.ss.EncryptAndHash(out, payload)
if err != nil {
return nil, nil, nil, err
}
if s.msgIdx >= len(s.messagePatterns) {
cs1, cs2 := s.ss.Split()
return out, cs1, cs2, nil
}
return out, nil, nil, nil
} }
// WriteMessageAndGetPK appends a handshake message to out. outPK can possibly contain the // WriteMessageAndGetPK appends a handshake message to out. outPK can possibly contain the
@ -459,15 +388,15 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState
// one is used for encryption of messages to the remote peer, the other is used // one is used for encryption of messages to the remote peer, the other is used
// for decryption of messages from the remote peer. It is an error to call this // for decryption of messages from the remote peer. It is an error to call this
// method out of sync with the handshake pattern. // method out of sync with the handshake pattern.
func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte, payload []byte) ([]byte, *CipherState, *CipherState, error) { func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outPK [][]byte, payload []byte, extraAd []byte) ([]byte, [][]byte, *CipherState, *CipherState, error) {
if !s.shouldWrite { if !s.shouldWrite {
return nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage") return nil, nil, nil, nil, errors.New("noise: unexpected call to WriteMessage should be ReadMessage")
} }
if s.msgIdx > len(s.messagePatterns)-1 { if s.msgIdx > len(s.messagePatterns)-1 {
return nil, nil, nil, errors.New("noise: no handshake messages left") return nil, nil, nil, nil, errors.New("noise: no handshake messages left")
} }
if len(payload) > MaxMsgLen { if len(payload) > MaxMsgLen {
return nil, nil, nil, errors.New("noise: message is too long") return nil, nil, nil, nil, errors.New("noise: message is too long")
} }
var err error var err error
@ -476,45 +405,43 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte,
case MessagePatternE: case MessagePatternE:
e, err := s.ss.cs.GenerateKeypair(s.rng) e, err := s.ss.cs.GenerateKeypair(s.rng)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
s.e = e s.e = e
out = append(out, s.e.Public...) out = append(out, s.e.Public...)
if outNoisePK != nil { outPK = append(outPK, s.e.Public)
*outNoisePK = append(*outNoisePK, s.e.Public)
}
s.ss.MixHash(s.e.Public) s.ss.MixHash(s.e.Public)
if len(s.psk) > 0 { if len(s.psk) > 0 {
s.ss.MixKey(s.e.Public) s.ss.MixKey(s.e.Public)
} }
case MessagePatternS: case MessagePatternS:
if len(s.s.Public) == 0 { if len(s.s.Public) == 0 {
return nil, nil, nil, errors.New("noise: invalid state, s.Public is nil") return nil, nil, nil, nil, errors.New("noise: invalid state, s.Public is nil")
} }
out, err = s.ss.EncryptAndHash(out, s.s.Public) out, err = s.ss.EncryptAndHash(out, s.s.Public)
outPK = append(outPK, out)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
}
if outNoisePK != nil {
*outNoisePK = append(*outNoisePK, out)
} }
case MessagePatternDHEE: case MessagePatternDHEE:
dh, err := s.ss.cs.DH(s.e.Private, s.re) dh, err := s.ss.cs.DH(s.e.Private, s.re)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
s.ss.MixKey(dh) s.ss.MixKey(dh)
case MessagePatternDHES: case MessagePatternDHES:
if s.initiator { if s.initiator {
dh, err := s.ss.cs.DH(s.e.Private, s.rs) dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
s.ss.MixKey(dh) s.ss.MixKey(dh)
} else { } else {
dh, err := s.ss.cs.DH(s.s.Private, s.re) dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
s.ss.MixKey(dh) s.ss.MixKey(dh)
} }
@ -522,20 +449,20 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte,
if s.initiator { if s.initiator {
dh, err := s.ss.cs.DH(s.s.Private, s.re) dh, err := s.ss.cs.DH(s.s.Private, s.re)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
s.ss.MixKey(dh) s.ss.MixKey(dh)
} else { } else {
dh, err := s.ss.cs.DH(s.e.Private, s.rs) dh, err := s.ss.cs.DH(s.e.Private, s.rs)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
s.ss.MixKey(dh) s.ss.MixKey(dh)
} }
case MessagePatternDHSS: case MessagePatternDHSS:
dh, err := s.ss.cs.DH(s.s.Private, s.rs) dh, err := s.ss.cs.DH(s.s.Private, s.rs)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
s.ss.MixKey(dh) s.ss.MixKey(dh)
case MessagePatternPSK: case MessagePatternPSK:
@ -544,17 +471,21 @@ func (s *HandshakeState) WriteMessageAndGetPK(out []byte, outNoisePK *[][]byte,
} }
s.shouldWrite = false s.shouldWrite = false
s.msgIdx++ s.msgIdx++
out, err = s.ss.EncryptAndHash(out, payload) out, err = s.ss.EncryptAndHash(out, payload, extraAd...)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, nil, err
} }
if s.msgIdx >= len(s.messagePatterns) { if s.msgIdx >= len(s.messagePatterns) {
cs1, cs2 := s.ss.Split() cs1, cs2 := s.ss.Split()
return out, cs1, cs2, nil return out, outPK, cs1, cs2, nil
} }
return out, nil, nil, nil return out, outPK, nil, nil, nil
}
func (s *HandshakeState) Hash() hash.Hash {
return s.ss.cs.Hash()
} }
// ErrShortMessage is returned by ReadMessage if a message is not as long as it should be. // ErrShortMessage is returned by ReadMessage if a message is not as long as it should be.
@ -565,7 +496,7 @@ var ErrShortMessage = errors.New("noise: message is too short")
// will be returned, one is used for encryption of messages to the remote peer, // will be returned, one is used for encryption of messages to the remote peer,
// the other is used for decryption of messages from the remote peer. It is an // the other is used for decryption of messages from the remote peer. It is an
// error to call this method out of sync with the handshake pattern. // error to call this method out of sync with the handshake pattern.
func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, *CipherState, error) { func (s *HandshakeState) ReadMessage(out, message []byte, extraAd ...byte) ([]byte, *CipherState, *CipherState, error) {
if s.shouldWrite { if s.shouldWrite {
return nil, nil, nil, errors.New("noise: unexpected call to ReadMessage should be WriteMessage") return nil, nil, nil, errors.New("noise: unexpected call to ReadMessage should be WriteMessage")
} }
@ -657,7 +588,7 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState,
s.ss.MixKeyAndHash(s.psk) s.ss.MixKeyAndHash(s.psk)
} }
} }
out, err = s.ss.DecryptAndHash(out, message) out, err = s.ss.DecryptAndHash(out, message, extraAd...)
if err != nil { if err != nil {
s.ss.Rollback() s.ss.Rollback()
if rsSet { if rsSet {

View File

@ -7,7 +7,7 @@ import (
"io" "io"
"os" "os"
. "github.com/flynn/noise" . "github.com/waku-org/noise"
) )
func main() { func main() {