265 lines
8.0 KiB
Go
Raw Normal View History

package noise
import (
"context"
"crypto/rand"
"encoding/binary"
"fmt"
pool "github.com/libp2p/go-buffer-pool"
"golang.org/x/crypto/poly1305"
"time"
"github.com/flynn/noise"
"github.com/gogo/protobuf/proto"
"github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-noise/pb"
)
// payloadSigPrefix is the domain-separation prefix prepended to a Noise static
// public key before that key is signed with (or verified against) a libp2p
// identity key.
const payloadSigPrefix = "noise-libp2p-static-key:"
// cipherSuite is the fixed cipher suite shared by all noise sessions:
// the DH25519 key exchange, the ChaChaPoly cipher and the SHA256 hash.
var cipherSuite = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
// runHandshake drives the three-message Noise XX handshake with the remote
// peer and blocks until it completes or fails.
//
// A fresh static keypair is generated for the session and bound to our libp2p
// identity through the signed handshake payload. If ctx carries a deadline it
// is applied to the session for the duration of the handshake and cleared
// again on return (only when applying it succeeded).
//
// It returns nil on success, or an error describing the step that failed.
func (s *secureSession) runHandshake(ctx context.Context) error {
	staticKeys, err := noise.DH25519.GenerateKeypair(rand.Reader)
	if err != nil {
		return fmt.Errorf("error generating static keypair: %w", err)
	}

	state, err := noise.NewHandshakeState(noise.Config{
		CipherSuite:   cipherSuite,
		Pattern:       noise.HandshakeXX,
		Initiator:     s.initiator,
		StaticKeypair: staticKeys,
	})
	if err != nil {
		return fmt.Errorf("error initializing handshake state: %w", err)
	}

	localPayload, err := s.generateHandshakePayload(staticKeys)
	if err != nil {
		return err
	}

	// Bound the handshake by the context deadline, if there is one, and
	// remove that deadline again once the handshake is over.
	if limit, hasLimit := ctx.Deadline(); hasLimit {
		if s.SetDeadline(limit) == nil {
			defer s.SetDeadline(time.Time{})
		}
	}

	// One buffer is shared by every outgoing handshake message. It is sized
	// for the largest message of the XX pattern (two DH keys, the payload and
	// two MACs) plus the length prefix that precedes each message on the wire.
	largestMsg := 2*noise.DH25519.DHLen() + len(localPayload) + 2*poly1305.TagSize
	msgBuf := pool.Get(largestMsg + LengthPrefixLength)
	defer pool.Put(msgBuf)

	if !s.initiator {
		// Responder, message 1 (inbound): the initiator's first message; any
		// payload it carries is ignored.
		if _, err := s.readHandshakeMessage(state); err != nil {
			return fmt.Errorf("error reading handshake message: %w", err)
		}

		// Responder, message 2 (outbound):
		// DH ephemeral key + encrypted DH static key (with MAC) +
		// encrypted payload (with MAC).
		if err := s.sendHandshakeMessage(state, localPayload, msgBuf); err != nil {
			return fmt.Errorf("error sending handshake message: %w", err)
		}

		// Responder, message 3 (inbound): carries the initiator's payload.
		remotePayload, err := s.readHandshakeMessage(state)
		if err != nil {
			return fmt.Errorf("error reading handshake message: %w", err)
		}
		return s.handleRemoteHandshakePayload(remotePayload, state.PeerStatic())
	}

	// Initiator, message 1 (outbound): DH ephemeral key only. The payload is
	// withheld because this message would carry it in plaintext.
	if err := s.sendHandshakeMessage(state, nil, msgBuf); err != nil {
		return fmt.Errorf("error sending handshake message: %w", err)
	}

	// Initiator, message 2 (inbound): carries the responder's payload.
	remotePayload, err := s.readHandshakeMessage(state)
	if err != nil {
		return fmt.Errorf("error reading handshake message: %w", err)
	}
	if err := s.handleRemoteHandshakePayload(remotePayload, state.PeerStatic()); err != nil {
		return err
	}

	// Initiator, message 3 (outbound):
	// encrypted DH static key (with MAC) + encrypted payload (with MAC).
	if err := s.sendHandshakeMessage(state, localPayload, msgBuf); err != nil {
		return fmt.Errorf("error sending handshake message: %w", err)
	}
	return nil
}
// setCipherStates installs the cipher states that protect traffic once the
// handshake has finished.
//
// The initiator encrypts with cs1 and decrypts with cs2; the responder uses
// them the other way round. sendHandshakeMessage and readHandshakeMessage call
// this when processing a handshake message yields both cipher states.
func (s *secureSession) setCipherStates(cs1, cs2 *noise.CipherState) {
	if !s.initiator {
		s.enc, s.dec = cs2, cs1
		return
	}
	s.enc, s.dec = cs1, cs2
}
// sendHandshakeMessage writes the next handshake message in the sequence to
// the insecure connection, framed with a big-endian uint16 length prefix.
//
// payload, if non-empty, is included in the handshake message. hbuf is the
// scratch buffer the message is assembled in; it must have capacity for at
// least LengthPrefixLength bytes.
//
// If the message turns out to be the final one in the sequence (both cipher
// states are returned), setCipherStates is called to initialize them.
//
// An error is returned if the handshake state rejects the message, if the
// message is too large to be described by the length prefix, or if the write
// to the connection fails.
func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload []byte, hbuf []byte) error {
	// Largest message length representable in the uint16 length prefix.
	const maxPrefixedLen = 1<<16 - 1

	// Reserve the leading LengthPrefixLength bytes for the length prefix and
	// let the handshake state append the noise message after them.
	bz, cs1, cs2, err := hs.WriteMessage(hbuf[:LengthPrefixLength], payload)
	if err != nil {
		return err
	}

	// bz includes the reserved prefix bytes, so subtract them to get the
	// length of the noise message itself. Refuse to send a message whose
	// length would be silently truncated by the uint16 conversion, as the
	// remote would then mis-frame the stream.
	msgLen := len(bz) - LengthPrefixLength
	if msgLen > maxPrefixedLen {
		return fmt.Errorf("handshake message too large: %d bytes exceeds the %d byte limit", msgLen, maxPrefixedLen)
	}
	binary.BigEndian.PutUint16(bz, uint16(msgLen))

	if _, err := s.writeMsgInsecure(bz); err != nil {
		return err
	}

	if cs1 != nil && cs2 != nil {
		s.setCipherStates(cs1, cs2)
	}
	return nil
}
// readHandshakeMessage reads one length-prefixed message from the insecure
// connection and feeds it to the handshake state as the next expected message
// in the sequence.
//
// It returns the payload carried by the message, as produced by the handshake
// state. If processing the message yields both cipher states (i.e. it was the
// final handshake message), setCipherStates is called to initialize them.
func (s *secureSession) readHandshakeMessage(hs *noise.HandshakeState) ([]byte, error) {
	msgLen, err := s.readNextInsecureMsgLen()
	if err != nil {
		return nil, err
	}

	// The raw message lives in a pooled buffer that is handed back on return.
	raw := pool.Get(msgLen)
	defer pool.Put(raw)

	if err = s.readNextMsgInsecure(raw); err != nil {
		return nil, err
	}

	// NOTE(review): a nil destination is passed so the returned payload should
	// not alias the pooled buffer — confirm against ReadMessage's contract.
	body, csA, csB, err := hs.ReadMessage(nil, raw)
	switch {
	case err != nil:
		return nil, err
	case csA != nil && csB != nil:
		s.setCipherStates(csA, csB)
	}
	return body, nil
}
// generateHandshakePayload builds the serialized libp2p handshake payload for
// this session.
//
// The payload contains our serialized libp2p identity public key and a
// signature, made with our libp2p private key, over payloadSigPrefix followed
// by the public half of localStatic (our Noise static key).
//
// It returns the protobuf-encoded payload, or an error if the identity key
// cannot be serialized, signing fails, or marshaling fails.
func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byte, error) {
	// Serialize our libp2p identity public key so the remote peer can learn it
	// from the payload.
	localKeyRaw, err := s.LocalPublicKey().Bytes()
	if err != nil {
		return nil, fmt.Errorf("error serializing libp2p identity key: %w", err)
	}

	// Sign the prefixed Noise static public key with our libp2p private key.
	toSign := append([]byte(payloadSigPrefix), localStatic.Public...)
	signedPayload, err := s.localKey.Sign(toSign)
	if err != nil {
		return nil, fmt.Errorf("error signing handshake payload: %w", err)
	}

	// Assemble and encode the payload.
	payload := new(pb.NoiseHandshakePayload)
	payload.IdentityKey = localKeyRaw
	payload.IdentitySig = signedPayload
	payloadEnc, err := proto.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("error marshaling handshake payload: %w", err)
	}
	return payloadEnc, nil
}
// handleRemoteHandshakePayload decodes the handshake payload sent by the
// remote peer and checks that its signature binds the peer's libp2p identity
// key to remoteStatic, the peer's static Noise key.
//
// When we are the initiator, the peer ID derived from the remote identity key
// must also match the ID we expected to reach. On success the session's
// remoteID and remoteKey are set; on any failure they are left untouched and
// an error is returned.
func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) error {
	// Decode the payload.
	nhp := new(pb.NoiseHandshakePayload)
	if err := proto.Unmarshal(payload, nhp); err != nil {
		return fmt.Errorf("error unmarshaling remote handshake payload: %w", err)
	}

	// Unpack the remote peer's libp2p public key and derive its peer ID.
	remotePubKey, err := crypto.UnmarshalPublicKey(nhp.GetIdentityKey())
	if err != nil {
		return fmt.Errorf("error unmarshaling remote identity key: %w", err)
	}
	id, err := peer.IDFromPublicKey(remotePubKey)
	if err != nil {
		return fmt.Errorf("error deriving peer id from remote identity key: %w", err)
	}

	// If we know who we're trying to reach, make sure we have the right peer.
	if s.initiator && s.remoteID != id {
		// Pretty() yields the full b58-encoded string rather than an abbreviated form.
		return fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty())
	}

	// Verify that the asserted libp2p key signed the prefixed static Noise key.
	sig := nhp.GetIdentitySig()
	msg := append([]byte(payloadSigPrefix), remoteStatic...)
	ok, err := remotePubKey.Verify(msg, sig)
	if err != nil {
		return fmt.Errorf("error verifying signature: %w", err)
	}
	if !ok {
		return fmt.Errorf("handshake signature invalid")
	}

	// Only record the remote identity once everything has checked out.
	s.remoteID = id
	s.remoteKey = remotePubKey
	return nil
}