factor out processing remote handshake payload

This commit is contained in:
Yusef Napora 2020-03-02 16:33:28 -05:00
parent 70efae2bed
commit 5974ecc852
1 changed files with 39 additions and 59 deletions

View File

@ -119,13 +119,34 @@ func (s *secureSession) generateHandshakePayload() ([]byte, error) {
return payloadEnc, nil
}
func (s *secureSession) handleRemoteHandshakePayload(payload []byte) error {
// unmarshal payload
nhp := new(pb.NoiseHandshakePayload)
err := proto.Unmarshal(payload, nhp)
if err != nil {
return fmt.Errorf("error unmarshaling remote handshake payload: %s", err)
}
// set remote libp2p public key
err = s.setRemotePeerInfo(nhp.GetIdentityKey())
if err != nil {
return fmt.Errorf("error processing remote libp2p key: %s", err)
}
// verify payload is signed by libp2p key
err = s.verifyPayload(nhp, s.ns.hs.PeerStatic())
if err != nil {
return fmt.Errorf("error validating handshake signature: %s", err)
}
return nil
}
// Runs the XX handshake
// XX:
// -> e
// <- e, ee, s, es
// -> s, se
func (s *secureSession) runHandshake(ctx context.Context) (err error) {
cfg := noise.Config{
CipherSuite: cipherSuite,
Pattern: noise.HandshakeXX,
@ -146,88 +167,47 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
if s.ns.initiator {
// stage 0 //
err = s.sendHandshakeMessage(nil)
if err != nil {
return fmt.Errorf("error sending handshake message: %s", err)
}
// stage 1 //
plaintext, err := s.readHandshakeMessage()
if err != nil {
return fmt.Errorf("error reading handshake message: %s", err)
}
err = s.handleRemoteHandshakePayload(plaintext)
if err != nil {
return err
}
// read reply
// stage 2 //
err = s.sendHandshakeMessage(payload)
if err != nil {
return fmt.Errorf("error sending handshake message: %s", err)
}
} else {
// stage 0 //
plaintext, err := s.readHandshakeMessage()
if err != nil {
return fmt.Errorf("error reading handshake message: %s", err)
}
// stage 2 //
err = s.sendHandshakeMessage(payload)
if err != nil {
return fmt.Errorf("error sending handshake message: %s", err)
}
// unmarshal payload
nhp := new(pb.NoiseHandshakePayload)
err = proto.Unmarshal(plaintext, nhp)
if err != nil {
return fmt.Errorf("error unmarshaling remote handshake payload: %s", err)
}
// set remote libp2p public key
err = s.setRemotePeerInfo(nhp.GetIdentityKey())
if err != nil {
return fmt.Errorf("error processing remote libp2p key: %s", err)
}
// verify payload is signed by libp2p key
err = s.verifyPayload(nhp, s.ns.hs.PeerStatic())
if err != nil {
return fmt.Errorf("error validating handshake signature: %s", err)
}
} else {
// stage 0 //
var plaintext []byte
nhp := new(pb.NoiseHandshakePayload)
// read message
plaintext, err = s.readHandshakeMessage()
if err != nil {
return fmt.Errorf("error reading handshake message: %s", err)
}
// stage 1 //
err = s.sendHandshakeMessage(payload)
if err != nil {
return fmt.Errorf("error sending handshake message: %s", err)
}
// stage 2 //
// read message
plaintext, err = s.readHandshakeMessage()
if err != nil {
return fmt.Errorf("error reading handshake message: %s", err)
}
// unmarshal payload
err = proto.Unmarshal(plaintext, nhp)
err = s.handleRemoteHandshakePayload(plaintext)
if err != nil {
return fmt.Errorf("error unmarshaling remote handshake payload: %s", err)
}
// set remote libp2p public key
err = s.setRemotePeerInfo(nhp.GetIdentityKey())
if err != nil {
return fmt.Errorf("error processing remote libp2p key: %s", err)
}
// verify payload is signed by libp2p key
err = s.verifyPayload(nhp, s.ns.hs.PeerStatic())
if err != nil {
return fmt.Errorf("error validating handshake signature: %s", err)
return err
}
}