diff --git a/p2p/security/noise/crypto.go b/p2p/security/noise/crypto.go index 3048f8c3..86e206aa 100644 --- a/p2p/security/noise/crypto.go +++ b/p2p/security/noise/crypto.go @@ -3,20 +3,20 @@ package noise import "errors" func (s *secureSession) encrypt(plaintext []byte) (ciphertext []byte, err error) { - if s.ns.enc == nil { + if s.enc == nil { return nil, errors.New("cannot encrypt, handshake incomplete") } // TODO: use pre-allocated buffers - ciphertext = s.ns.enc.Encrypt(nil, nil, plaintext) + ciphertext = s.enc.Encrypt(nil, nil, plaintext) return ciphertext, nil } func (s *secureSession) decrypt(ciphertext []byte) (plaintext []byte, err error) { - if s.ns.dec == nil { + if s.dec == nil { return nil, errors.New("cannot decrypt, handshake incomplete") } // TODO: use pre-allocated buffers - return s.ns.dec.Decrypt(nil, nil, ciphertext) + return s.dec.Decrypt(nil, nil, ciphertext) } diff --git a/p2p/security/noise/handshake.go b/p2p/security/noise/handshake.go index cbe5c94f..23c6d637 100644 --- a/p2p/security/noise/handshake.go +++ b/p2p/security/noise/handshake.go @@ -41,62 +41,57 @@ func (s *secureSession) runHandshake(ctx context.Context) error { return fmt.Errorf("error initializing handshake state: %s", err) } - s.ns.hs = hs - s.ns.localStatic = kp - - payload, err := s.generateHandshakePayload() + payload, err := s.generateHandshakePayload(kp) if err != nil { return err } if s.initiator { // stage 0 // - err = s.sendHandshakeMessage(nil) + err = s.sendHandshakeMessage(hs, nil) if err != nil { return fmt.Errorf("error sending handshake message: %s", err) } // stage 1 // - plaintext, err := s.readHandshakeMessage() + plaintext, err := s.readHandshakeMessage(hs) if err != nil { return fmt.Errorf("error reading handshake message: %s", err) } - err = s.handleRemoteHandshakePayload(plaintext) + err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()) if err != nil { return err } // stage 2 // - err = s.sendHandshakeMessage(payload) + err = s.sendHandshakeMessage(hs, payload) if err != nil { return fmt.Errorf("error sending handshake message: %s", err) } } else { // stage 0 // - plaintext, err := s.readHandshakeMessage() + plaintext, err := s.readHandshakeMessage(hs) if err != nil { return fmt.Errorf("error reading handshake message: %s", err) } // stage 1 // - err = s.sendHandshakeMessage(payload) + err = s.sendHandshakeMessage(hs, payload) if err != nil { return fmt.Errorf("error sending handshake message: %s", err) } // stage 2 // - plaintext, err = s.readHandshakeMessage() + plaintext, err = s.readHandshakeMessage(hs) if err != nil { return fmt.Errorf("error reading handshake message: %s", err) } - err = s.handleRemoteHandshakePayload(plaintext) + err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()) if err != nil { return err } } - // we can discard the handshake state once the handshake completes - s.ns.hs = nil return nil } @@ -105,21 +100,20 @@ func (s *secureSession) runHandshake(ctx context.Context) error { // It sets the initial cipher states that will be used to protect traffic after the handshake. func (s *secureSession) setCipherStates(cs1, cs2 *noise.CipherState) { if s.initiator { - s.ns.enc = cs1 - s.ns.dec = cs2 + s.enc = cs1 + s.dec = cs2 } else { - s.ns.enc = cs2 - s.ns.dec = cs1 + s.enc = cs2 + s.dec = cs1 } } // sendHandshakeMessage sends the next handshake message in the sequence. -// Only safe to call from runHandshake, as it depends on handshake state. // If payload is non-empty, it will be included in the handshake message. // If this is the final message in the sequence, calls setCipherStates // to initialize cipher states. -func (s *secureSession) sendHandshakeMessage(payload []byte) error { - buf, cs1, cs2, err := s.ns.hs.WriteMessage(nil, payload) +func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload []byte) error { + buf, cs1, cs2, err := hs.WriteMessage(nil, payload) if err != nil { return err } @@ -137,16 +131,15 @@ func (s *secureSession) sendHandshakeMessage(payload []byte) error { // readHandshakeMessage reads a message from the insecure conn and tries to // process it as the expected next message in the handshake sequence. -// Only safe to call from runHandshake, as it depends on handshake state. // If the message contains a payload, it will be decrypted and returned. // If this is the final message in the sequence, calls setCipherStates // to initialize cipher states. -func (s *secureSession) readHandshakeMessage() ([]byte, error) { +func (s *secureSession) readHandshakeMessage(hs *noise.HandshakeState) ([]byte, error) { raw, err := s.readMsgInsecure() if err != nil { return nil, err } - msg, cs1, cs2, err := s.ns.hs.ReadMessage(nil, raw) + msg, cs1, cs2, err := hs.ReadMessage(nil, raw) if err != nil { return nil, err } @@ -158,8 +151,7 @@ func (s *secureSession) readHandshakeMessage() ([]byte, error) { // generateHandshakePayload creates a libp2p handshake payload with a // signature of our static noise key. -// Must be called after the static key for the session has been generated. -func (s *secureSession) generateHandshakePayload() ([]byte, error) { +func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byte, error) { // setup libp2p keys localKeyRaw, err := s.LocalPublicKey().Bytes() if err != nil { @@ -167,7 +159,7 @@ func (s *secureSession) generateHandshakePayload() ([]byte, error) { } // sign noise data for payload - toSign := append([]byte(payloadSigPrefix), s.ns.localStatic.Public...) + toSign := append([]byte(payloadSigPrefix), localStatic.Public...) signedPayload, err := s.localKey.Sign(toSign) if err != nil { return nil, fmt.Errorf("error sigining handshake payload: %s", err) @@ -186,8 +178,7 @@ func (s *secureSession) generateHandshakePayload() ([]byte, error) { // handleRemoteHandshakePayload unmarshals the handshake payload object sent // by the remote peer and validates the signature against the peer's static Noise key. -// Only safe to call from runHandshake, as it depends on handshake state. -func (s *secureSession) handleRemoteHandshakePayload(payload []byte) error { +func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) error { // unmarshal payload nhp := new(pb.NoiseHandshakePayload) err := proto.Unmarshal(payload, nhp) @@ -212,7 +203,6 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte) error { // verify payload is signed by libp2p key sig := nhp.GetIdentitySig() - remoteStatic := s.ns.hs.PeerStatic() msg := append([]byte(payloadSigPrefix), remoteStatic...) ok, err := remotePubKey.Verify(msg, sig) if err != nil { diff --git a/p2p/security/noise/session.go b/p2p/security/noise/session.go index fe176551..a88619af 100644 --- a/p2p/security/noise/session.go +++ b/p2p/security/noise/session.go @@ -7,21 +7,13 @@ import ( "time" "github.com/flynn/noise" + "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/peer" ) -type noiseState struct { - localStatic noise.DHKey - - hs *noise.HandshakeState - enc *noise.CipherState - dec *noise.CipherState -} - type secureSession struct { initiator bool - ns noiseState localID peer.ID localKey crypto.PrivKey @@ -32,6 +24,9 @@ type secureSession struct { msgBuffer []byte readLock sync.Mutex writeLock sync.Mutex + + enc *noise.CipherState + dec *noise.CipherState } // newSecureSession creates a noise session over the given insecure Conn, using the