From 5974ecc85295885ad61b1a912c9d2d0586b29392 Mon Sep 17 00:00:00 2001 From: Yusef Napora Date: Mon, 2 Mar 2020 16:33:28 -0500 Subject: [PATCH] factor out processing remote handshake payload --- p2p/security/noise/handshake.go | 98 +++++++++++++-------------------- 1 file changed, 39 insertions(+), 59 deletions(-) diff --git a/p2p/security/noise/handshake.go b/p2p/security/noise/handshake.go index 682bcf2c..a2502345 100644 --- a/p2p/security/noise/handshake.go +++ b/p2p/security/noise/handshake.go @@ -119,13 +119,34 @@ func (s *secureSession) generateHandshakePayload() ([]byte, error) { return payloadEnc, nil } +func (s *secureSession) handleRemoteHandshakePayload(payload []byte) error { + // unmarshal payload + nhp := new(pb.NoiseHandshakePayload) + err := proto.Unmarshal(payload, nhp) + if err != nil { + return fmt.Errorf("error unmarshaling remote handshake payload: %s", err) + } + + // set remote libp2p public key + err = s.setRemotePeerInfo(nhp.GetIdentityKey()) + if err != nil { + return fmt.Errorf("error processing remote libp2p key: %s", err) + } + + // verify payload is signed by libp2p key + err = s.verifyPayload(nhp, s.ns.hs.PeerStatic()) + if err != nil { + return fmt.Errorf("error validating handshake signature: %s", err) + } + return nil +} + // Runs the XX handshake // XX: // -> e // <- e, ee, s, es // -> s, se func (s *secureSession) runHandshake(ctx context.Context) (err error) { - cfg := noise.Config{ CipherSuite: cipherSuite, Pattern: noise.HandshakeXX, @@ -146,88 +167,47 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { if s.ns.initiator { // stage 0 // - err = s.sendHandshakeMessage(nil) if err != nil { return fmt.Errorf("error sending handshake message: %s", err) } // stage 1 // + plaintext, err := s.readHandshakeMessage() + if err != nil { + return fmt.Errorf("error reading handshake message: %s", err) + } + err = s.handleRemoteHandshakePayload(plaintext) + if err != nil { + return err + } - // read reply + // stage 2 // + err = s.sendHandshakeMessage(payload) + if err != nil { + return fmt.Errorf("error sending handshake message: %s", err) + } + } else { + // stage 0 // plaintext, err := s.readHandshakeMessage() if err != nil { return fmt.Errorf("error reading handshake message: %s", err) } - // stage 2 // - err = s.sendHandshakeMessage(payload) - if err != nil { - return fmt.Errorf("error sending handshake message: %s", err) - } - - // unmarshal payload - nhp := new(pb.NoiseHandshakePayload) - err = proto.Unmarshal(plaintext, nhp) - if err != nil { - return fmt.Errorf("error unmarshaling remote handshake payload: %s", err) - } - - // set remote libp2p public key - err = s.setRemotePeerInfo(nhp.GetIdentityKey()) - if err != nil { - return fmt.Errorf("error processing remote libp2p key: %s", err) - } - - // verify payload is signed by libp2p key - err = s.verifyPayload(nhp, s.ns.hs.PeerStatic()) - if err != nil { - return fmt.Errorf("error validating handshake signature: %s", err) - } - - } else { - - // stage 0 // - var plaintext []byte - nhp := new(pb.NoiseHandshakePayload) - - // read message - plaintext, err = s.readHandshakeMessage() - if err != nil { - return fmt.Errorf("error reading handshake message: %s", err) - } - // stage 1 // - err = s.sendHandshakeMessage(payload) if err != nil { return fmt.Errorf("error sending handshake message: %s", err) } // stage 2 // - - // read message plaintext, err = s.readHandshakeMessage() if err != nil { return fmt.Errorf("error reading handshake message: %s", err) } - - // unmarshal payload - err = proto.Unmarshal(plaintext, nhp) + err = s.handleRemoteHandshakePayload(plaintext) if err != nil { - return fmt.Errorf("error unmarshaling remote handshake payload: %s", err) - } - - // set remote libp2p public key - err = s.setRemotePeerInfo(nhp.GetIdentityKey()) - if err != nil { - return fmt.Errorf("error processing remote libp2p key: %s", err) - } - - // verify payload is signed by libp2p key - err = s.verifyPayload(nhp, s.ns.hs.PeerStatic()) - if err != nil { - return fmt.Errorf("error validating handshake signature: %s", err) + return err } }