make handshake state local to runHandshake

This commit is contained in:
Yusef Napora 2020-03-03 09:31:36 -05:00
parent 1edb96a9e1
commit dd7ccf8247
3 changed files with 28 additions and 43 deletions

View File

@ -3,20 +3,20 @@ package noise
import "errors"
func (s *secureSession) encrypt(plaintext []byte) (ciphertext []byte, err error) {
if s.ns.enc == nil {
if s.enc == nil {
return nil, errors.New("cannot encrypt, handshake incomplete")
}
// TODO: use pre-allocated buffers
ciphertext = s.ns.enc.Encrypt(nil, nil, plaintext)
ciphertext = s.enc.Encrypt(nil, nil, plaintext)
return ciphertext, nil
}
func (s *secureSession) decrypt(ciphertext []byte) (plaintext []byte, err error) {
if s.ns.dec == nil {
if s.dec == nil {
return nil, errors.New("cannot decrypt, handshake incomplete")
}
// TODO: use pre-allocated buffers
return s.ns.dec.Decrypt(nil, nil, ciphertext)
return s.dec.Decrypt(nil, nil, ciphertext)
}

View File

@ -41,62 +41,57 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
return fmt.Errorf("error initializing handshake state: %s", err)
}
s.ns.hs = hs
s.ns.localStatic = kp
payload, err := s.generateHandshakePayload()
payload, err := s.generateHandshakePayload(kp)
if err != nil {
return err
}
if s.initiator {
// stage 0 //
err = s.sendHandshakeMessage(nil)
err = s.sendHandshakeMessage(hs, nil)
if err != nil {
return fmt.Errorf("error sending handshake message: %s", err)
}
// stage 1 //
plaintext, err := s.readHandshakeMessage()
plaintext, err := s.readHandshakeMessage(hs)
if err != nil {
return fmt.Errorf("error reading handshake message: %s", err)
}
err = s.handleRemoteHandshakePayload(plaintext)
err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
if err != nil {
return err
}
// stage 2 //
err = s.sendHandshakeMessage(payload)
err = s.sendHandshakeMessage(hs, payload)
if err != nil {
return fmt.Errorf("error sending handshake message: %s", err)
}
} else {
// stage 0 //
plaintext, err := s.readHandshakeMessage()
plaintext, err := s.readHandshakeMessage(hs)
if err != nil {
return fmt.Errorf("error reading handshake message: %s", err)
}
// stage 1 //
err = s.sendHandshakeMessage(payload)
err = s.sendHandshakeMessage(hs, payload)
if err != nil {
return fmt.Errorf("error sending handshake message: %s", err)
}
// stage 2 //
plaintext, err = s.readHandshakeMessage()
plaintext, err = s.readHandshakeMessage(hs)
if err != nil {
return fmt.Errorf("error reading handshake message: %s", err)
}
err = s.handleRemoteHandshakePayload(plaintext)
err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
if err != nil {
return err
}
}
// we can discard the handshake state once the handshake completes
s.ns.hs = nil
return nil
}
@ -105,21 +100,20 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
// It sets the initial cipher states that will be used to protect traffic after the handshake.
func (s *secureSession) setCipherStates(cs1, cs2 *noise.CipherState) {
if s.initiator {
s.ns.enc = cs1
s.ns.dec = cs2
s.enc = cs1
s.dec = cs2
} else {
s.ns.enc = cs2
s.ns.dec = cs1
s.enc = cs2
s.dec = cs1
}
}
// sendHandshakeMessage sends the next handshake message in the sequence.
// Only safe to call from runHandshake, as it depends on handshake state.
// If payload is non-empty, it will be included in the handshake message.
// If this is the final message in the sequence, calls setCipherStates
// to initialize cipher states.
func (s *secureSession) sendHandshakeMessage(payload []byte) error {
buf, cs1, cs2, err := s.ns.hs.WriteMessage(nil, payload)
func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload []byte) error {
buf, cs1, cs2, err := hs.WriteMessage(nil, payload)
if err != nil {
return err
}
@ -137,16 +131,15 @@ func (s *secureSession) sendHandshakeMessage(payload []byte) error {
// readHandshakeMessage reads a message from the insecure conn and tries to
// process it as the expected next message in the handshake sequence.
// Only safe to call from runHandshake, as it depends on handshake state.
// If the message contains a payload, it will be decrypted and returned.
// If this is the final message in the sequence, calls setCipherStates
// to initialize cipher states.
func (s *secureSession) readHandshakeMessage() ([]byte, error) {
func (s *secureSession) readHandshakeMessage(hs *noise.HandshakeState) ([]byte, error) {
raw, err := s.readMsgInsecure()
if err != nil {
return nil, err
}
msg, cs1, cs2, err := s.ns.hs.ReadMessage(nil, raw)
msg, cs1, cs2, err := hs.ReadMessage(nil, raw)
if err != nil {
return nil, err
}
@ -158,8 +151,7 @@ func (s *secureSession) readHandshakeMessage() ([]byte, error) {
// generateHandshakePayload creates a libp2p handshake payload with a
// signature of our static noise key.
// Must be called after the static key for the session has been generated.
func (s *secureSession) generateHandshakePayload() ([]byte, error) {
func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byte, error) {
// setup libp2p keys
localKeyRaw, err := s.LocalPublicKey().Bytes()
if err != nil {
@ -167,7 +159,7 @@ func (s *secureSession) generateHandshakePayload() ([]byte, error) {
}
// sign noise data for payload
toSign := append([]byte(payloadSigPrefix), s.ns.localStatic.Public...)
toSign := append([]byte(payloadSigPrefix), localStatic.Public...)
signedPayload, err := s.localKey.Sign(toSign)
if err != nil {
return nil, fmt.Errorf("error sigining handshake payload: %s", err)
@ -186,8 +178,7 @@ func (s *secureSession) generateHandshakePayload() ([]byte, error) {
// handleRemoteHandshakePayload unmarshals the handshake payload object sent
// by the remote peer and validates the signature against the peer's static Noise key.
// Only safe to call from runHandshake, as it depends on handshake state.
func (s *secureSession) handleRemoteHandshakePayload(payload []byte) error {
func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) error {
// unmarshal payload
nhp := new(pb.NoiseHandshakePayload)
err := proto.Unmarshal(payload, nhp)
@ -212,7 +203,6 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte) error {
// verify payload is signed by libp2p key
sig := nhp.GetIdentitySig()
remoteStatic := s.ns.hs.PeerStatic()
msg := append([]byte(payloadSigPrefix), remoteStatic...)
ok, err := remotePubKey.Verify(msg, sig)
if err != nil {

View File

@ -7,21 +7,13 @@ import (
"time"
"github.com/flynn/noise"
"github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer"
)
type noiseState struct {
localStatic noise.DHKey
hs *noise.HandshakeState
enc *noise.CipherState
dec *noise.CipherState
}
type secureSession struct {
initiator bool
ns noiseState
localID peer.ID
localKey crypto.PrivKey
@ -32,6 +24,9 @@ type secureSession struct {
msgBuffer []byte
readLock sync.Mutex
writeLock sync.Mutex
enc *noise.CipherState
dec *noise.CipherState
}
// newSecureSession creates a noise session over the given insecure Conn, using the