From 9a626dbd0b73f5d17c0cca61ba2aa560120da1a3 Mon Sep 17 00:00:00 2001 From: Jonathan Rudenberg Date: Tue, 12 Jul 2016 21:25:40 -0400 Subject: [PATCH] Unexport SymmetricState This struct is only used internally. Signed-off-by: Jonathan Rudenberg --- state.go | 94 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 47 insertions(+), 47 deletions(-) diff --git a/state.go b/state.go index 22182cb..5cfa789 100644 --- a/state.go +++ b/state.go @@ -24,7 +24,7 @@ func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) { return out, err } -type SymmetricState struct { +type symmetricState struct { CipherState hasK bool hasPSK bool @@ -32,7 +32,7 @@ type SymmetricState struct { h []byte } -func (s *SymmetricState) InitializeSymmetric(handshakeName []byte) { +func (s *symmetricState) InitializeSymmetric(handshakeName []byte) { h := s.cs.Hash() if len(handshakeName) <= h.Size() { s.h = make([]byte, h.Size()) @@ -45,7 +45,7 @@ func (s *SymmetricState) InitializeSymmetric(handshakeName []byte) { copy(s.ck, s.h) } -func (s *SymmetricState) MixKey(dhOutput []byte) { +func (s *symmetricState) MixKey(dhOutput []byte) { s.n = 0 s.hasK = true var hk []byte @@ -54,21 +54,21 @@ func (s *SymmetricState) MixKey(dhOutput []byte) { s.c = s.cs.Cipher(s.k) } -func (s *SymmetricState) MixHash(data []byte) { +func (s *symmetricState) MixHash(data []byte) { h := s.cs.Hash() h.Write(s.h) h.Write(data) s.h = h.Sum(s.h[:0]) } -func (s *SymmetricState) MixPresharedKey(presharedKey []byte) { +func (s *symmetricState) MixPresharedKey(presharedKey []byte) { var temp []byte s.ck, temp = HKDF(s.cs.Hash, s.ck[:0], nil, s.ck, presharedKey) s.MixHash(temp) s.hasPSK = true } -func (s *SymmetricState) EncryptAndHash(out, plaintext []byte) []byte { +func (s *symmetricState) EncryptAndHash(out, plaintext []byte) []byte { if !s.hasK { s.MixHash(plaintext) return append(out, plaintext...) @@ -78,7 +78,7 @@ func (s *SymmetricState) EncryptAndHash(out, plaintext []byte) []byte { return ciphertext } -func (s *SymmetricState) DecryptAndHash(out, data []byte) ([]byte, error) { +func (s *symmetricState) DecryptAndHash(out, data []byte) ([]byte, error) { if !s.hasK { s.MixHash(data) return append(out, data...), nil @@ -91,7 +91,7 @@ func (s *SymmetricState) DecryptAndHash(out, data []byte) ([]byte, error) { return plaintext, nil } -func (s *SymmetricState) Split() (*CipherState, *CipherState) { +func (s *symmetricState) Split() (*CipherState, *CipherState) { s1, s2 := &CipherState{cs: s.cs}, &CipherState{cs: s.cs} hk1, hk2 := HKDF(s.cs.Hash, s1.k[:0], s2.k[:0], s.ck, nil) copy(s1.k[:], hk1) @@ -122,7 +122,7 @@ const ( const MaxMsgLen = 65535 type HandshakeState struct { - SymmetricState + ss symmetricState s DHKey // local static keypair e DHKey // local ephemeral keypair rs []byte // remote party's static public key @@ -159,38 +159,38 @@ func NewHandshakeState(c Config) *HandshakeState { hs.re = make([]byte, len(c.PeerEphemeral)) copy(hs.re, c.PeerEphemeral) } - hs.SymmetricState.cs = c.CipherSuite + hs.ss.cs = c.CipherSuite namePrefix := "Noise_" if len(c.PresharedKey) > 0 { namePrefix = "NoisePSK_" } - hs.InitializeSymmetric([]byte(namePrefix + c.Pattern.Name + "_" + string(hs.cs.Name()))) - hs.MixHash(c.Prologue) + hs.ss.InitializeSymmetric([]byte(namePrefix + c.Pattern.Name + "_" + string(hs.ss.cs.Name()))) + hs.ss.MixHash(c.Prologue) if len(c.PresharedKey) > 0 { - hs.MixPresharedKey(c.PresharedKey) + hs.ss.MixPresharedKey(c.PresharedKey) } for _, m := range c.Pattern.InitiatorPreMessages { switch { case c.Initiator && m == MessagePatternS: - hs.MixHash(hs.s.Public) + hs.ss.MixHash(hs.s.Public) case c.Initiator && m == MessagePatternE: - hs.MixHash(hs.e.Public) + hs.ss.MixHash(hs.e.Public) case !c.Initiator && m == MessagePatternS: - hs.MixHash(hs.rs) + hs.ss.MixHash(hs.rs) case !c.Initiator && m == MessagePatternE: - hs.MixHash(hs.re) + hs.ss.MixHash(hs.re) } } for _, m := range c.Pattern.ResponderPreMessages { switch { case !c.Initiator && m == MessagePatternS: - hs.MixHash(hs.s.Public) + hs.ss.MixHash(hs.s.Public) case !c.Initiator && m == MessagePatternE: - hs.MixHash(hs.e.Public) + hs.ss.MixHash(hs.e.Public) case c.Initiator && m == MessagePatternS: - hs.MixHash(hs.rs) + hs.ss.MixHash(hs.rs) case c.Initiator && m == MessagePatternE: - hs.MixHash(hs.re) + hs.ss.MixHash(hs.re) } } return hs @@ -210,33 +210,33 @@ func (s *HandshakeState) WriteMessage(out, payload []byte) ([]byte, *CipherState for _, msg := range s.messagePatterns[s.msgIdx] { switch msg { case MessagePatternE: - s.e = s.cs.GenerateKeypair(s.rng) + s.e = s.ss.cs.GenerateKeypair(s.rng) out = append(out, s.e.Public...) - s.MixHash(s.e.Public) - if s.hasPSK { - s.MixKey(s.e.Public) + s.ss.MixHash(s.e.Public) + if s.ss.hasPSK { + s.ss.MixKey(s.e.Public) } case MessagePatternS: if len(s.s.Public) == 0 { panic("noise: invalid state, s.Public is nil") } - out = s.EncryptAndHash(out, s.s.Public) + out = s.ss.EncryptAndHash(out, s.s.Public) case MessagePatternDHEE: - s.MixKey(s.cs.DH(s.e.Private, s.re)) + s.ss.MixKey(s.ss.cs.DH(s.e.Private, s.re)) case MessagePatternDHES: - s.MixKey(s.cs.DH(s.e.Private, s.rs)) + s.ss.MixKey(s.ss.cs.DH(s.e.Private, s.rs)) case MessagePatternDHSE: - s.MixKey(s.cs.DH(s.s.Private, s.re)) + s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.re)) case MessagePatternDHSS: - s.MixKey(s.cs.DH(s.s.Private, s.rs)) + s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs)) } } s.shouldWrite = false s.msgIdx++ - out = s.EncryptAndHash(out, payload) + out = s.ss.EncryptAndHash(out, payload) if s.msgIdx >= len(s.messagePatterns) { - cs1, cs2 := s.Split() + cs1, cs2 := s.ss.Split() return out, cs1, cs2 } @@ -257,8 +257,8 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, for _, msg := range s.messagePatterns[s.msgIdx] { switch msg { case MessagePatternE, MessagePatternS: - expected := s.cs.DHLen() - if msg == MessagePatternS && s.hasK { + expected := s.ss.cs.DHLen() + if msg == MessagePatternS && s.ss.hasK { expected += 16 } if len(message) < expected { @@ -266,44 +266,44 @@ func (s *HandshakeState) ReadMessage(out, message []byte) ([]byte, *CipherState, } switch msg { case MessagePatternE: - if cap(s.re) < s.cs.DHLen() { - s.re = make([]byte, s.cs.DHLen()) + if cap(s.re) < s.ss.cs.DHLen() { + s.re = make([]byte, s.ss.cs.DHLen()) } - s.re = s.re[:s.cs.DHLen()] + s.re = s.re[:s.ss.cs.DHLen()] copy(s.re, message) - s.MixHash(s.re) - if s.hasPSK { - s.MixKey(s.re) + s.ss.MixHash(s.re) + if s.ss.hasPSK { + s.ss.MixKey(s.re) } case MessagePatternS: if len(s.rs) > 0 { panic("noise: invalid state, rs is not nil") } - s.rs, err = s.DecryptAndHash(s.rs[:0], message[:expected]) + s.rs, err = s.ss.DecryptAndHash(s.rs[:0], message[:expected]) } if err != nil { return nil, nil, nil, err } message = message[expected:] case MessagePatternDHEE: - s.MixKey(s.cs.DH(s.e.Private, s.re)) + s.ss.MixKey(s.ss.cs.DH(s.e.Private, s.re)) case MessagePatternDHES: - s.MixKey(s.cs.DH(s.s.Private, s.re)) + s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.re)) case MessagePatternDHSE: - s.MixKey(s.cs.DH(s.e.Private, s.rs)) + s.ss.MixKey(s.ss.cs.DH(s.e.Private, s.rs)) case MessagePatternDHSS: - s.MixKey(s.cs.DH(s.s.Private, s.rs)) + s.ss.MixKey(s.ss.cs.DH(s.s.Private, s.rs)) } } s.shouldWrite = true s.msgIdx++ - out, err = s.DecryptAndHash(out, message) + out, err = s.ss.DecryptAndHash(out, message) if err != nil { return nil, nil, nil, err } if s.msgIdx >= len(s.messagePatterns) { - cs1, cs2 := s.Split() + cs1, cs2 := s.ss.Split() return out, cs1, cs2, nil }