359 lines
8.7 KiB
Go
359 lines
8.7 KiB
Go
package noise
|
|
|
|
import (
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"time"

	proto "github.com/gogo/protobuf/proto"
	logging "github.com/ipfs/go-log"
	"github.com/libp2p/go-libp2p-core/crypto"
	"github.com/libp2p/go-libp2p-core/peer"

	ik "github.com/libp2p/go-libp2p-noise/ik"
	pb "github.com/libp2p/go-libp2p-noise/pb"
	xx "github.com/libp2p/go-libp2p-noise/xx"
)
|
|
|
|
const payload_string = "noise-libp2p-static-key:"
|
|
|
|
// Each encrypted transport message must be <= 65,535 bytes, including 16
|
|
// bytes of authentication data. To write larger plaintexts, we split them
|
|
// into fragments of maxPlaintextLength before encrypting.
|
|
const maxPlaintextLength = 65519
|
|
|
|
var log = logging.Logger("noise")
|
|
|
|
var errNoKeypair = errors.New("cannot initiate secureSession - transport has no noise keypair")
|
|
|
|
type secureSession struct {
|
|
insecure net.Conn
|
|
|
|
initiator bool
|
|
prologue []byte
|
|
|
|
localKey crypto.PrivKey
|
|
localPeer peer.ID
|
|
remotePeer peer.ID
|
|
|
|
local peerInfo
|
|
remote peerInfo
|
|
|
|
xx_ns *xx.NoiseSession
|
|
ik_ns *ik.NoiseSession
|
|
|
|
xx_complete bool
|
|
ik_complete bool
|
|
|
|
noisePipesSupport bool
|
|
noiseStaticKeyCache *KeyCache
|
|
|
|
noiseKeypair *Keypair
|
|
|
|
msgBuffer []byte
|
|
rwLock sync.Mutex
|
|
}
|
|
|
|
type peerInfo struct {
|
|
noiseKey [32]byte // static noise public key
|
|
libp2pKey crypto.PubKey
|
|
}
|
|
|
|
// newSecureSession creates a noise session over the given insecure Conn, using the static
|
|
// Noise keypair and libp2p identity keypair from the given Transport.
|
|
//
|
|
// If tpt.noisePipesSupport == true, the Noise Pipes handshake protocol will be used,
|
|
// which consists of the IK and XXfallback handshake patterns. With Noise Pipes on, we first try IK,
|
|
// if that fails, move to XXfallback. With Noise Pipes off, we always do XX.
|
|
func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, initiator bool) (*secureSession, error) {
|
|
if tpt.noiseKeypair == nil {
|
|
return nil, errNoKeypair
|
|
}
|
|
|
|
// if the transport doesn't have a key cache, we make a new one just for
|
|
// this session. it's a bit of a waste, but saves us having to check if
|
|
// it's nil later
|
|
keyCache := tpt.noiseStaticKeyCache
|
|
if keyCache == nil {
|
|
keyCache = NewKeyCache()
|
|
}
|
|
|
|
localPeerInfo := peerInfo{
|
|
noiseKey: tpt.noiseKeypair.publicKey,
|
|
libp2pKey: tpt.privateKey.GetPublic(),
|
|
}
|
|
|
|
s := &secureSession{
|
|
insecure: insecure,
|
|
initiator: initiator,
|
|
prologue: []byte{},
|
|
localKey: tpt.privateKey,
|
|
localPeer: tpt.localID,
|
|
remotePeer: remote,
|
|
local: localPeerInfo,
|
|
noisePipesSupport: tpt.noisePipesSupport,
|
|
noiseStaticKeyCache: keyCache,
|
|
msgBuffer: []byte{},
|
|
noiseKeypair: tpt.noiseKeypair,
|
|
}
|
|
|
|
err := s.runHandshake(ctx)
|
|
|
|
return s, err
|
|
}
|
|
|
|
func (s *secureSession) NoiseStaticKeyCache() *KeyCache {
|
|
return s.noiseStaticKeyCache
|
|
}
|
|
|
|
func (s *secureSession) NoisePublicKey() [32]byte {
|
|
return s.noiseKeypair.publicKey
|
|
}
|
|
|
|
func (s *secureSession) NoisePrivateKey() [32]byte {
|
|
return s.noiseKeypair.privateKey
|
|
}
|
|
|
|
func (s *secureSession) readLength() (int, error) {
|
|
buf := make([]byte, 2)
|
|
_, err := s.insecure.Read(buf)
|
|
return int(binary.BigEndian.Uint16(buf)), err
|
|
}
|
|
|
|
func (s *secureSession) writeLength(length int) error {
|
|
buf := make([]byte, 2)
|
|
binary.BigEndian.PutUint16(buf, uint16(length))
|
|
_, err := s.insecure.Write(buf)
|
|
return err
|
|
}
|
|
|
|
func (s *secureSession) setRemotePeerInfo(key []byte) (err error) {
|
|
s.remote.libp2pKey, err = crypto.UnmarshalPublicKey(key)
|
|
return err
|
|
}
|
|
|
|
func (s *secureSession) setRemotePeerID(key crypto.PubKey) (err error) {
|
|
s.remotePeer, err = peer.IDFromPublicKey(key)
|
|
return err
|
|
}
|
|
|
|
func (s *secureSession) verifyPayload(payload *pb.NoiseHandshakePayload, noiseKey [32]byte) (err error) {
|
|
sig := payload.GetIdentitySig()
|
|
msg := append([]byte(payload_string), noiseKey[:]...)
|
|
|
|
log.Debugf("verifyPayload msg=%x", msg)
|
|
|
|
ok, err := s.RemotePublicKey().Verify(msg, sig)
|
|
if err != nil {
|
|
return err
|
|
} else if !ok {
|
|
return fmt.Errorf("did not verify payload")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *secureSession) runHandshake(ctx context.Context) error {
|
|
// setup libp2p keys
|
|
localKeyRaw, err := s.LocalPublicKey().Bytes()
|
|
if err != nil {
|
|
return fmt.Errorf("runHandshake err getting raw pubkey: %s", err)
|
|
}
|
|
|
|
// sign noise data for payload
|
|
noise_pub := s.noiseKeypair.publicKey
|
|
signedPayload, err := s.localKey.Sign(append([]byte(payload_string), noise_pub[:]...))
|
|
if err != nil {
|
|
log.Errorf("runHandshake signing payload err=%s", err)
|
|
return fmt.Errorf("runHandshake signing payload err=%s", err)
|
|
}
|
|
|
|
// create payload
|
|
payload := new(pb.NoiseHandshakePayload)
|
|
payload.IdentityKey = localKeyRaw
|
|
payload.IdentitySig = signedPayload
|
|
payloadEnc, err := proto.Marshal(payload)
|
|
if err != nil {
|
|
log.Errorf("runHandshake marshal payload err=%s", err)
|
|
return fmt.Errorf("runHandshake proto marshal payload err=%s", err)
|
|
}
|
|
|
|
// If we support Noise pipes, we try IK first, falling back to XX if IK fails.
|
|
// The exception is when we're the initiator and don't know the other party's
|
|
// static Noise key. Then IK will always fail, so we go straight to XX.
|
|
tryIK := s.noisePipesSupport
|
|
if s.initiator && s.noiseStaticKeyCache.Load(s.remotePeer) == [32]byte{} {
|
|
tryIK = false
|
|
}
|
|
if tryIK {
|
|
// we're either a responder or an initiator with a known static key for the remote peer, try IK
|
|
buf, err := s.runHandshake_ik(ctx, payloadEnc)
|
|
if err != nil {
|
|
log.Error("runHandshake ik err=%s", err)
|
|
|
|
// IK failed, pipe to XXfallback
|
|
err = s.runHandshake_xx(ctx, true, payloadEnc, buf)
|
|
if err != nil {
|
|
log.Error("runHandshake xx err=err", err)
|
|
return fmt.Errorf("runHandshake xx err=%s", err)
|
|
}
|
|
|
|
s.xx_complete = true
|
|
} else {
|
|
s.ik_complete = true
|
|
}
|
|
} else {
|
|
// unknown static key for peer, try XX
|
|
err := s.runHandshake_xx(ctx, false, payloadEnc, nil)
|
|
if err != nil {
|
|
log.Error("runHandshake xx err=%s", err)
|
|
return err
|
|
}
|
|
|
|
s.xx_complete = true
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *secureSession) LocalAddr() net.Addr {
|
|
return s.insecure.LocalAddr()
|
|
}
|
|
|
|
func (s *secureSession) LocalPeer() peer.ID {
|
|
return s.localPeer
|
|
}
|
|
|
|
func (s *secureSession) LocalPrivateKey() crypto.PrivKey {
|
|
return s.localKey
|
|
}
|
|
|
|
func (s *secureSession) LocalPublicKey() crypto.PubKey {
|
|
return s.localKey.GetPublic()
|
|
}
|
|
|
|
func (s *secureSession) Read(buf []byte) (int, error) {
|
|
l := len(buf)
|
|
|
|
// if we have previously unread bytes, and they fit into the buf, copy them over and return
|
|
if l <= len(s.msgBuffer) {
|
|
copy(buf, s.msgBuffer)
|
|
s.msgBuffer = s.msgBuffer[l:]
|
|
return l, nil
|
|
}
|
|
|
|
readChunk := func(buf []byte) (int, error) {
|
|
// read length of encrypted message
|
|
l, err := s.readLength()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// read and decrypt ciphertext
|
|
ciphertext := make([]byte, l)
|
|
_, err = s.insecure.Read(ciphertext)
|
|
if err != nil {
|
|
log.Error("read ciphertext err", err)
|
|
return 0, err
|
|
}
|
|
|
|
plaintext, err := s.Decrypt(ciphertext)
|
|
if err != nil {
|
|
log.Error("decrypt err", err)
|
|
return 0, err
|
|
}
|
|
|
|
// append plaintext to message buffer, copy over what can fit in the buf
|
|
// then advance message buffer to remove what was copied
|
|
s.msgBuffer = append(s.msgBuffer, plaintext...)
|
|
c := copy(buf, s.msgBuffer)
|
|
s.msgBuffer = s.msgBuffer[c:]
|
|
return c, nil
|
|
}
|
|
|
|
total := 0
|
|
for i := 0; i < len(buf); i += maxPlaintextLength {
|
|
end := i + maxPlaintextLength
|
|
if end > len(buf) {
|
|
end = len(buf)
|
|
}
|
|
|
|
c, err := readChunk(buf[i:end])
|
|
total += c
|
|
if err != nil {
|
|
return total, err
|
|
}
|
|
}
|
|
|
|
return total, nil
|
|
}
|
|
|
|
func (s *secureSession) RemoteAddr() net.Addr {
|
|
return s.insecure.RemoteAddr()
|
|
}
|
|
|
|
func (s *secureSession) RemotePeer() peer.ID {
|
|
return s.remotePeer
|
|
}
|
|
|
|
func (s *secureSession) RemotePublicKey() crypto.PubKey {
|
|
return s.remote.libp2pKey
|
|
}
|
|
|
|
func (s *secureSession) SetDeadline(t time.Time) error {
|
|
return s.insecure.SetDeadline(t)
|
|
}
|
|
|
|
func (s *secureSession) SetReadDeadline(t time.Time) error {
|
|
return s.insecure.SetReadDeadline(t)
|
|
}
|
|
|
|
func (s *secureSession) SetWriteDeadline(t time.Time) error {
|
|
return s.insecure.SetWriteDeadline(t)
|
|
}
|
|
|
|
func (s *secureSession) Write(in []byte) (int, error) {
|
|
s.rwLock.Lock()
|
|
defer s.rwLock.Unlock()
|
|
|
|
writeChunk := func(in []byte) (int, error) {
|
|
ciphertext, err := s.Encrypt(in)
|
|
if err != nil {
|
|
log.Error("encrypt error", err)
|
|
return 0, err
|
|
}
|
|
|
|
err = s.writeLength(len(ciphertext))
|
|
if err != nil {
|
|
log.Error("write length err", err)
|
|
return 0, err
|
|
}
|
|
|
|
_, err = s.insecure.Write(ciphertext)
|
|
return len(in), err
|
|
}
|
|
|
|
written := 0
|
|
for i := 0; i < len(in); i += maxPlaintextLength {
|
|
end := i + maxPlaintextLength
|
|
if end > len(in) {
|
|
end = len(in)
|
|
}
|
|
|
|
l, err := writeChunk(in[i:end])
|
|
written += l
|
|
if err != nil {
|
|
return written, err
|
|
}
|
|
}
|
|
return written, nil
|
|
}
|
|
|
|
func (s *secureSession) Close() error {
|
|
return s.insecure.Close()
|
|
}
|