factor out processing remote handshake payload
This commit is contained in:
parent
70efae2bed
commit
5974ecc852
|
@ -119,13 +119,34 @@ func (s *secureSession) generateHandshakePayload() ([]byte, error) {
|
|||
return payloadEnc, nil
|
||||
}
|
||||
|
||||
func (s *secureSession) handleRemoteHandshakePayload(payload []byte) error {
|
||||
// unmarshal payload
|
||||
nhp := new(pb.NoiseHandshakePayload)
|
||||
err := proto.Unmarshal(payload, nhp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error unmarshaling remote handshake payload: %s", err)
|
||||
}
|
||||
|
||||
// set remote libp2p public key
|
||||
err = s.setRemotePeerInfo(nhp.GetIdentityKey())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error processing remote libp2p key: %s", err)
|
||||
}
|
||||
|
||||
// verify payload is signed by libp2p key
|
||||
err = s.verifyPayload(nhp, s.ns.hs.PeerStatic())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error validating handshake signature: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Runs the XX handshake
|
||||
// XX:
|
||||
// -> e
|
||||
// <- e, ee, s, es
|
||||
// -> s, se
|
||||
func (s *secureSession) runHandshake(ctx context.Context) (err error) {
|
||||
|
||||
cfg := noise.Config{
|
||||
CipherSuite: cipherSuite,
|
||||
Pattern: noise.HandshakeXX,
|
||||
|
@ -146,88 +167,47 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
|
|||
|
||||
if s.ns.initiator {
|
||||
// stage 0 //
|
||||
|
||||
err = s.sendHandshakeMessage(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending handshake message: %s", err)
|
||||
}
|
||||
|
||||
// stage 1 //
|
||||
plaintext, err := s.readHandshakeMessage()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading handshake message: %s", err)
|
||||
}
|
||||
err = s.handleRemoteHandshakePayload(plaintext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// read reply
|
||||
// stage 2 //
|
||||
err = s.sendHandshakeMessage(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending handshake message: %s", err)
|
||||
}
|
||||
} else {
|
||||
// stage 0 //
|
||||
plaintext, err := s.readHandshakeMessage()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading handshake message: %s", err)
|
||||
}
|
||||
|
||||
// stage 2 //
|
||||
err = s.sendHandshakeMessage(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending handshake message: %s", err)
|
||||
}
|
||||
|
||||
// unmarshal payload
|
||||
nhp := new(pb.NoiseHandshakePayload)
|
||||
err = proto.Unmarshal(plaintext, nhp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error unmarshaling remote handshake payload: %s", err)
|
||||
}
|
||||
|
||||
// set remote libp2p public key
|
||||
err = s.setRemotePeerInfo(nhp.GetIdentityKey())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error processing remote libp2p key: %s", err)
|
||||
}
|
||||
|
||||
// verify payload is signed by libp2p key
|
||||
err = s.verifyPayload(nhp, s.ns.hs.PeerStatic())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error validating handshake signature: %s", err)
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
// stage 0 //
|
||||
var plaintext []byte
|
||||
nhp := new(pb.NoiseHandshakePayload)
|
||||
|
||||
// read message
|
||||
plaintext, err = s.readHandshakeMessage()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading handshake message: %s", err)
|
||||
}
|
||||
|
||||
// stage 1 //
|
||||
|
||||
err = s.sendHandshakeMessage(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending handshake message: %s", err)
|
||||
}
|
||||
|
||||
// stage 2 //
|
||||
|
||||
// read message
|
||||
plaintext, err = s.readHandshakeMessage()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading handshake message: %s", err)
|
||||
}
|
||||
|
||||
// unmarshal payload
|
||||
err = proto.Unmarshal(plaintext, nhp)
|
||||
err = s.handleRemoteHandshakePayload(plaintext)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error unmarshaling remote handshake payload: %s", err)
|
||||
}
|
||||
|
||||
// set remote libp2p public key
|
||||
err = s.setRemotePeerInfo(nhp.GetIdentityKey())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error processing remote libp2p key: %s", err)
|
||||
}
|
||||
|
||||
// verify payload is signed by libp2p key
|
||||
err = s.verifyPayload(nhp, s.ns.hs.PeerStatic())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error validating handshake signature: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue