From fc2bb37e287bd583b6699a2580710dc4c9e2deec Mon Sep 17 00:00:00 2001 From: Jonathan Rudenberg Date: Thu, 22 Apr 2021 13:00:17 -0400 Subject: [PATCH] Use X25519 instead of ScalarMult for safety (#43) --- cipher_suite.go | 21 +++++++-------- state.go | 72 ++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 70 insertions(+), 23 deletions(-) diff --git a/cipher_suite.go b/cipher_suite.go index 3098705..753e011 100644 --- a/cipher_suite.go +++ b/cipher_suite.go @@ -30,7 +30,7 @@ type DHFunc interface { // DH performs a Diffie-Hellman calculation between the provided private and // public keys and returns the result. - DH(privkey, pubkey []byte) []byte + DH(privkey, pubkey []byte) ([]byte, error) // DHLen is the number of bytes returned by DH. DHLen() int @@ -105,23 +105,22 @@ var DH25519 DHFunc = dh25519{} type dh25519 struct{} func (dh25519) GenerateKeypair(rng io.Reader) (DHKey, error) { - var pubkey, privkey [32]byte + privkey := make([]byte, 32) if rng == nil { rng = rand.Reader } - if _, err := io.ReadFull(rng, privkey[:]); err != nil { + if _, err := io.ReadFull(rng, privkey); err != nil { return DHKey{}, err } - curve25519.ScalarBaseMult(&pubkey, &privkey) - return DHKey{Private: privkey[:], Public: pubkey[:]}, nil + pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) + if err != nil { + return DHKey{}, err + } + return DHKey{Private: privkey, Public: pubkey}, nil } -func (dh25519) DH(privkey, pubkey []byte) []byte { - var dst, in, base [32]byte - copy(in[:], privkey) - copy(base[:], pubkey) - curve25519.ScalarMult(&dst, &in, &base) - return dst[:] +func (dh25519) DH(privkey, pubkey []byte) ([]byte, error) { + return curve25519.X25519(privkey, pubkey) } func (dh25519) DHLen() int { return 32 } diff --git a/state.go b/state.go index c4f3161..418522e 100644 --- a/state.go +++ b/state.go @@ -359,21 +359,45 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState } out = s.ss.EncryptAndHash(out, s.s.Public) case MessagePatternDHEE: - s.ss.MixKey(s.ss.cs.DH(s.e.Private, s.re)) + dh, err := s.ss.cs.DH(s.e.Private, s.re) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) case MessagePatternDHES: if s.initiator { - s.ss.MixKey(s.ss.cs.DH(s.e.Private, s.rs)) + dh, err := s.ss.cs.DH(s.e.Private, s.rs) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) } else { - s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.re)) + dh, err := s.ss.cs.DH(s.s.Private, s.re) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) } case MessagePatternDHSE: if s.initiator { - s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.re)) + dh, err := s.ss.cs.DH(s.s.Private, s.re) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) } else { - s.ss.MixKey(s.ss.cs.DH(s.e.Private, s.rs)) + dh, err := s.ss.cs.DH(s.e.Private, s.rs) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) } case MessagePatternDHSS: - s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs)) + dh, err := s.ss.cs.DH(s.s.Private, s.rs) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) case MessagePatternPSK: s.ss.MixKeyAndHash(s.psk) } @@ -447,21 +471,45 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, } message = message[expected:] case MessagePatternDHEE: - s.ss.MixKey(s.ss.cs.DH(s.e.Private, s.re)) + dh, err := s.ss.cs.DH(s.e.Private, s.re) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) case MessagePatternDHES: if s.initiator { - s.ss.MixKey(s.ss.cs.DH(s.e.Private, s.rs)) + dh, err := s.ss.cs.DH(s.e.Private, s.rs) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) } else { - s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.re)) + dh, err := s.ss.cs.DH(s.s.Private, s.re) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) } case MessagePatternDHSE: if s.initiator { - s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.re)) + dh, err := s.ss.cs.DH(s.s.Private, s.re) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) } else { - s.ss.MixKey(s.ss.cs.DH(s.e.Private, s.rs)) + dh, err := s.ss.cs.DH(s.e.Private, s.rs) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) } case MessagePatternDHSS: - s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs)) + dh, err := s.ss.cs.DH(s.s.Private, s.rs) + if err != nil { + return nil, nil, nil, err + } + s.ss.MixKey(dh) case MessagePatternPSK: s.ss.MixKeyAndHash(s.psk) }