make handshake state local to runHandshake
This commit is contained in:
parent
1edb96a9e1
commit
dd7ccf8247
|
@ -3,20 +3,20 @@ package noise
|
|||
import "errors"
|
||||
|
||||
func (s *secureSession) encrypt(plaintext []byte) (ciphertext []byte, err error) {
|
||||
if s.ns.enc == nil {
|
||||
if s.enc == nil {
|
||||
return nil, errors.New("cannot encrypt, handshake incomplete")
|
||||
}
|
||||
|
||||
// TODO: use pre-allocated buffers
|
||||
ciphertext = s.ns.enc.Encrypt(nil, nil, plaintext)
|
||||
ciphertext = s.enc.Encrypt(nil, nil, plaintext)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
func (s *secureSession) decrypt(ciphertext []byte) (plaintext []byte, err error) {
|
||||
if s.ns.dec == nil {
|
||||
if s.dec == nil {
|
||||
return nil, errors.New("cannot decrypt, handshake incomplete")
|
||||
}
|
||||
|
||||
// TODO: use pre-allocated buffers
|
||||
return s.ns.dec.Decrypt(nil, nil, ciphertext)
|
||||
return s.dec.Decrypt(nil, nil, ciphertext)
|
||||
}
|
||||
|
|
|
@ -41,62 +41,57 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
|
|||
return fmt.Errorf("error initializing handshake state: %s", err)
|
||||
}
|
||||
|
||||
s.ns.hs = hs
|
||||
s.ns.localStatic = kp
|
||||
|
||||
payload, err := s.generateHandshakePayload()
|
||||
payload, err := s.generateHandshakePayload(kp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.initiator {
|
||||
// stage 0 //
|
||||
err = s.sendHandshakeMessage(nil)
|
||||
err = s.sendHandshakeMessage(hs, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending handshake message: %s", err)
|
||||
}
|
||||
|
||||
// stage 1 //
|
||||
plaintext, err := s.readHandshakeMessage()
|
||||
plaintext, err := s.readHandshakeMessage(hs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading handshake message: %s", err)
|
||||
}
|
||||
err = s.handleRemoteHandshakePayload(plaintext)
|
||||
err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// stage 2 //
|
||||
err = s.sendHandshakeMessage(payload)
|
||||
err = s.sendHandshakeMessage(hs, payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending handshake message: %s", err)
|
||||
}
|
||||
} else {
|
||||
// stage 0 //
|
||||
plaintext, err := s.readHandshakeMessage()
|
||||
plaintext, err := s.readHandshakeMessage(hs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading handshake message: %s", err)
|
||||
}
|
||||
|
||||
// stage 1 //
|
||||
err = s.sendHandshakeMessage(payload)
|
||||
err = s.sendHandshakeMessage(hs, payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending handshake message: %s", err)
|
||||
}
|
||||
|
||||
// stage 2 //
|
||||
plaintext, err = s.readHandshakeMessage()
|
||||
plaintext, err = s.readHandshakeMessage(hs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading handshake message: %s", err)
|
||||
}
|
||||
err = s.handleRemoteHandshakePayload(plaintext)
|
||||
err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// we can discard the handshake state once the handshake completes
|
||||
s.ns.hs = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -105,21 +100,20 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
|
|||
// It sets the initial cipher states that will be used to protect traffic after the handshake.
|
||||
func (s *secureSession) setCipherStates(cs1, cs2 *noise.CipherState) {
|
||||
if s.initiator {
|
||||
s.ns.enc = cs1
|
||||
s.ns.dec = cs2
|
||||
s.enc = cs1
|
||||
s.dec = cs2
|
||||
} else {
|
||||
s.ns.enc = cs2
|
||||
s.ns.dec = cs1
|
||||
s.enc = cs2
|
||||
s.dec = cs1
|
||||
}
|
||||
}
|
||||
|
||||
// sendHandshakeMessage sends the next handshake message in the sequence.
|
||||
// Only safe to call from runHandshake, as it depends on handshake state.
|
||||
// If payload is non-empty, it will be included in the handshake message.
|
||||
// If this is the final message in the sequence, calls setCipherStates
|
||||
// to initialize cipher states.
|
||||
func (s *secureSession) sendHandshakeMessage(payload []byte) error {
|
||||
buf, cs1, cs2, err := s.ns.hs.WriteMessage(nil, payload)
|
||||
func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload []byte) error {
|
||||
buf, cs1, cs2, err := hs.WriteMessage(nil, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -137,16 +131,15 @@ func (s *secureSession) sendHandshakeMessage(payload []byte) error {
|
|||
|
||||
// readHandshakeMessage reads a message from the insecure conn and tries to
|
||||
// process it as the expected next message in the handshake sequence.
|
||||
// Only safe to call from runHandshake, as it depends on handshake state.
|
||||
// If the message contains a payload, it will be decrypted and returned.
|
||||
// If this is the final message in the sequence, calls setCipherStates
|
||||
// to initialize cipher states.
|
||||
func (s *secureSession) readHandshakeMessage() ([]byte, error) {
|
||||
func (s *secureSession) readHandshakeMessage(hs *noise.HandshakeState) ([]byte, error) {
|
||||
raw, err := s.readMsgInsecure()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg, cs1, cs2, err := s.ns.hs.ReadMessage(nil, raw)
|
||||
msg, cs1, cs2, err := hs.ReadMessage(nil, raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -158,8 +151,7 @@ func (s *secureSession) readHandshakeMessage() ([]byte, error) {
|
|||
|
||||
// generateHandshakePayload creates a libp2p handshake payload with a
|
||||
// signature of our static noise key.
|
||||
// Must be called after the static key for the session has been generated.
|
||||
func (s *secureSession) generateHandshakePayload() ([]byte, error) {
|
||||
func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byte, error) {
|
||||
// setup libp2p keys
|
||||
localKeyRaw, err := s.LocalPublicKey().Bytes()
|
||||
if err != nil {
|
||||
|
@ -167,7 +159,7 @@ func (s *secureSession) generateHandshakePayload() ([]byte, error) {
|
|||
}
|
||||
|
||||
// sign noise data for payload
|
||||
toSign := append([]byte(payloadSigPrefix), s.ns.localStatic.Public...)
|
||||
toSign := append([]byte(payloadSigPrefix), localStatic.Public...)
|
||||
signedPayload, err := s.localKey.Sign(toSign)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error sigining handshake payload: %s", err)
|
||||
|
@ -186,8 +178,7 @@ func (s *secureSession) generateHandshakePayload() ([]byte, error) {
|
|||
|
||||
// handleRemoteHandshakePayload unmarshals the handshake payload object sent
|
||||
// by the remote peer and validates the signature against the peer's static Noise key.
|
||||
// Only safe to call from runHandshake, as it depends on handshake state.
|
||||
func (s *secureSession) handleRemoteHandshakePayload(payload []byte) error {
|
||||
func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) error {
|
||||
// unmarshal payload
|
||||
nhp := new(pb.NoiseHandshakePayload)
|
||||
err := proto.Unmarshal(payload, nhp)
|
||||
|
@ -212,7 +203,6 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte) error {
|
|||
|
||||
// verify payload is signed by libp2p key
|
||||
sig := nhp.GetIdentitySig()
|
||||
remoteStatic := s.ns.hs.PeerStatic()
|
||||
msg := append([]byte(payloadSigPrefix), remoteStatic...)
|
||||
ok, err := remotePubKey.Verify(msg, sig)
|
||||
if err != nil {
|
||||
|
|
|
@ -7,21 +7,13 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
|
||||
"github.com/libp2p/go-libp2p-core/crypto"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
)
|
||||
|
||||
type noiseState struct {
|
||||
localStatic noise.DHKey
|
||||
|
||||
hs *noise.HandshakeState
|
||||
enc *noise.CipherState
|
||||
dec *noise.CipherState
|
||||
}
|
||||
|
||||
type secureSession struct {
|
||||
initiator bool
|
||||
ns noiseState
|
||||
|
||||
localID peer.ID
|
||||
localKey crypto.PrivKey
|
||||
|
@ -32,6 +24,9 @@ type secureSession struct {
|
|||
msgBuffer []byte
|
||||
readLock sync.Mutex
|
||||
writeLock sync.Mutex
|
||||
|
||||
enc *noise.CipherState
|
||||
dec *noise.CipherState
|
||||
}
|
||||
|
||||
// newSecureSession creates a noise session over the given insecure Conn, using the
|
||||
|
|
Loading…
Reference in New Issue