Merge pull request #58 from libp2p/raul-review
This commit is contained in:
commit
9b9c06f042
|
@ -124,6 +124,10 @@ func pipeRandom(src rand.Source, w io.WriteCloser, r io.Reader, n int64) error {
|
|||
func benchDataTransfer(b *benchenv, size int64) {
|
||||
var totalBytes int64
|
||||
var totalTime time.Duration
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
initSession, respSession := b.connect(true)
|
||||
|
||||
|
@ -153,6 +157,9 @@ func BenchmarkTransfer500Mb(b *testing.B) {
|
|||
}
|
||||
|
||||
func (b benchenv) benchHandshake() {
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
i, r := b.connect(false)
|
||||
b.StopTimer()
|
||||
|
|
|
@ -1,22 +1,46 @@
|
|||
package noise
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
func (s *secureSession) encrypt(plaintext []byte) (ciphertext []byte, err error) {
|
||||
// encrypt calls the cipher's encryption. It encrypts the provided plaintext,
|
||||
// slice-appending the ciphertext on out.
|
||||
//
|
||||
// Usually you want to pass a 0-len slice to this method, with enough capacity
|
||||
// to accommodate the ciphertext in order to spare allocs.
|
||||
//
|
||||
// encrypt returns a new slice header, whose len is the length of the resulting
|
||||
// ciphertext, including the authentication tag.
|
||||
//
|
||||
// This method will not allocate if the supplied slice is large enough to
|
||||
// accommodate the encrypted data + authentication tag. If so, the returned
|
||||
// slice header should be a view of the original slice.
|
||||
//
|
||||
// With the poly1305 MAC function that noise-libp2p uses, the authentication tag
|
||||
// adds an overhead of 16 bytes.
|
||||
func (s *secureSession) encrypt(out, plaintext []byte) ([]byte, error) {
|
||||
if s.enc == nil {
|
||||
return nil, errors.New("cannot encrypt, handshake incomplete")
|
||||
}
|
||||
|
||||
// TODO: use pre-allocated buffers
|
||||
ciphertext = s.enc.Encrypt(nil, nil, plaintext)
|
||||
return ciphertext, nil
|
||||
return s.enc.Encrypt(out, nil, plaintext), nil
|
||||
}
|
||||
|
||||
func (s *secureSession) decrypt(ciphertext []byte) (plaintext []byte, err error) {
|
||||
// decrypt calls the cipher's decryption. It decrypts the provided ciphertext,
|
||||
// slice-appending the plaintext on out.
|
||||
//
|
||||
// Usually you want to pass a 0-len slice to this method, with enough capacity
|
||||
// to accommodate the plaintext in order to spare allocs.
|
||||
//
|
||||
// decrypt returns a new slice header, whose len is the length of the resulting
|
||||
// plaintext, without the authentication tag.
|
||||
//
|
||||
// This method will not allocate if the supplied slice is large enough to
|
||||
// accommodate the plaintext. If so, the returned slice header should be a view
|
||||
// of the original slice.
|
||||
func (s *secureSession) decrypt(out, ciphertext []byte) ([]byte, error) {
|
||||
if s.dec == nil {
|
||||
return nil, errors.New("cannot decrypt, handshake incomplete")
|
||||
}
|
||||
|
||||
// TODO: use pre-allocated buffers
|
||||
return s.dec.Decrypt(nil, nil, ciphertext)
|
||||
return s.dec.Decrypt(out, nil, ciphertext)
|
||||
}
|
||||
|
|
|
@ -18,12 +18,12 @@ func TestEncryptAndDecrypt_InitToResp(t *testing.T) {
|
|||
defer respConn.Close()
|
||||
|
||||
plaintext := []byte("helloworld")
|
||||
ciphertext, err := initConn.encrypt(plaintext)
|
||||
ciphertext, err := initConn.encrypt(nil, plaintext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result, err := respConn.decrypt(ciphertext)
|
||||
result, err := respConn.decrypt(nil, ciphertext)
|
||||
if !bytes.Equal(plaintext, result) {
|
||||
t.Fatalf("got %x expected %x", result, plaintext)
|
||||
} else if err != nil {
|
||||
|
@ -31,12 +31,12 @@ func TestEncryptAndDecrypt_InitToResp(t *testing.T) {
|
|||
}
|
||||
|
||||
plaintext = []byte("goodbye")
|
||||
ciphertext, err = initConn.encrypt(plaintext)
|
||||
ciphertext, err = initConn.encrypt(nil, plaintext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result, err = respConn.decrypt(ciphertext)
|
||||
result, err = respConn.decrypt(nil, ciphertext)
|
||||
if !bytes.Equal(plaintext, result) {
|
||||
t.Fatalf("got %x expected %x", result, plaintext)
|
||||
} else if err != nil {
|
||||
|
@ -53,12 +53,12 @@ func TestEncryptAndDecrypt_RespToInit(t *testing.T) {
|
|||
defer respConn.Close()
|
||||
|
||||
plaintext := []byte("helloworld")
|
||||
ciphertext, err := respConn.encrypt(plaintext)
|
||||
ciphertext, err := respConn.encrypt(nil, plaintext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result, err := initConn.decrypt(ciphertext)
|
||||
result, err := initConn.decrypt(nil, ciphertext)
|
||||
if !bytes.Equal(plaintext, result) {
|
||||
t.Fatalf("got %x expected %x", result, plaintext)
|
||||
} else if err != nil {
|
||||
|
@ -75,14 +75,14 @@ func TestCryptoFailsIfCiphertextIsAltered(t *testing.T) {
|
|||
defer respConn.Close()
|
||||
|
||||
plaintext := []byte("helloworld")
|
||||
ciphertext, err := respConn.encrypt(plaintext)
|
||||
ciphertext, err := respConn.encrypt(nil, plaintext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ciphertext[0] = ^ciphertext[0]
|
||||
|
||||
_, err = initConn.decrypt(ciphertext)
|
||||
_, err = initConn.decrypt(nil, ciphertext)
|
||||
if err == nil {
|
||||
t.Fatal("expected decryption to fail when ciphertext altered")
|
||||
}
|
||||
|
@ -94,11 +94,11 @@ func TestCryptoFailsIfHandshakeIncomplete(t *testing.T) {
|
|||
_ = resp.Close()
|
||||
|
||||
session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", true)
|
||||
_, err := session.encrypt([]byte("hi"))
|
||||
_, err := session.encrypt(nil, []byte("hi"))
|
||||
if err == nil {
|
||||
t.Error("expected encryption error when handshake incomplete")
|
||||
}
|
||||
_, err = session.decrypt([]byte("it's a secret"))
|
||||
_, err = session.decrypt(nil, []byte("it's a secret"))
|
||||
if err == nil {
|
||||
t.Error("expected decryption error when handshake incomplete")
|
||||
}
|
||||
|
|
Binary file not shown.
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
|
@ -26,7 +27,7 @@ var cipherSuite = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, no
|
|||
func (s *secureSession) runHandshake(ctx context.Context) error {
|
||||
kp, err := noise.DH25519.GenerateKeypair(rand.Reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error generating static keypair: %s", err)
|
||||
return fmt.Errorf("error generating static keypair: %w", err)
|
||||
}
|
||||
|
||||
cfg := noise.Config{
|
||||
|
@ -38,7 +39,7 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
|
|||
|
||||
hs, err := noise.NewHandshakeState(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error initializing handshake state: %s", err)
|
||||
return fmt.Errorf("error initializing handshake state: %w", err)
|
||||
}
|
||||
|
||||
payload, err := s.generateHandshakePayload(kp)
|
||||
|
@ -46,17 +47,30 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// set a deadline to complete the handshake, if one has been supplied.
|
||||
// clear it after we're done.
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if err := s.SetDeadline(deadline); err == nil {
|
||||
// schedule the deadline removal once we're done handshaking.
|
||||
defer s.SetDeadline(time.Time{})
|
||||
}
|
||||
// TODO: else case (transport doesn't support native timeouts); spin off
|
||||
// a goroutine to monitor the context cancellation and pull the rug
|
||||
// from under by closing the connection altogether.
|
||||
}
|
||||
|
||||
if s.initiator {
|
||||
// stage 0 //
|
||||
// do not send the payload just yet, as it would be plaintext; not secret.
|
||||
err = s.sendHandshakeMessage(hs, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending handshake message: %s", err)
|
||||
return fmt.Errorf("error sending handshake message: %w", err)
|
||||
}
|
||||
|
||||
// stage 1 //
|
||||
plaintext, err := s.readHandshakeMessage(hs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading handshake message: %s", err)
|
||||
return fmt.Errorf("error reading handshake message: %w", err)
|
||||
}
|
||||
err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
|
||||
if err != nil {
|
||||
|
@ -66,25 +80,25 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
|
|||
// stage 2 //
|
||||
err = s.sendHandshakeMessage(hs, payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending handshake message: %s", err)
|
||||
return fmt.Errorf("error sending handshake message: %w", err)
|
||||
}
|
||||
} else {
|
||||
// stage 0 //
|
||||
plaintext, err := s.readHandshakeMessage(hs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading handshake message: %s", err)
|
||||
return fmt.Errorf("error reading handshake message: %w", err)
|
||||
}
|
||||
|
||||
// stage 1 //
|
||||
err = s.sendHandshakeMessage(hs, payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending handshake message: %s", err)
|
||||
return fmt.Errorf("error sending handshake message: %w", err)
|
||||
}
|
||||
|
||||
// stage 2 //
|
||||
plaintext, err = s.readHandshakeMessage(hs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading handshake message: %s", err)
|
||||
return fmt.Errorf("error reading handshake message: %w", err)
|
||||
}
|
||||
err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
|
||||
if err != nil {
|
||||
|
@ -95,9 +109,11 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// setCipherStates is called when the final handshake message is processed by
|
||||
// setCipherStates sets the initial cipher states that will be used to protect
|
||||
// traffic after the handshake.
|
||||
//
|
||||
// It is called when the final handshake message is processed by
|
||||
// either sendHandshakeMessage or readHandshakeMessage.
|
||||
// It sets the initial cipher states that will be used to protect traffic after the handshake.
|
||||
func (s *secureSession) setCipherStates(cs1, cs2 *noise.CipherState) {
|
||||
if s.initiator {
|
||||
s.enc = cs1
|
||||
|
@ -109,6 +125,7 @@ func (s *secureSession) setCipherStates(cs1, cs2 *noise.CipherState) {
|
|||
}
|
||||
|
||||
// sendHandshakeMessage sends the next handshake message in the sequence.
|
||||
//
|
||||
// If payload is non-empty, it will be included in the handshake message.
|
||||
// If this is the final message in the sequence, calls setCipherStates
|
||||
// to initialize cipher states.
|
||||
|
@ -118,7 +135,7 @@ func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload [
|
|||
return err
|
||||
}
|
||||
|
||||
err = s.writeMsgInsecure(buf)
|
||||
_, err = s.writeMsgInsecure(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -131,8 +148,10 @@ func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload [
|
|||
|
||||
// readHandshakeMessage reads a message from the insecure conn and tries to
|
||||
// process it as the expected next message in the handshake sequence.
|
||||
//
|
||||
// If the message contains a payload, it will be decrypted and returned.
|
||||
// If this is the final message in the sequence, calls setCipherStates
|
||||
//
|
||||
// If this is the final message in the sequence, it calls setCipherStates
|
||||
// to initialize cipher states.
|
||||
func (s *secureSession) readHandshakeMessage(hs *noise.HandshakeState) ([]byte, error) {
|
||||
raw, err := s.readMsgInsecure()
|
||||
|
@ -152,17 +171,18 @@ func (s *secureSession) readHandshakeMessage(hs *noise.HandshakeState) ([]byte,
|
|||
// generateHandshakePayload creates a libp2p handshake payload with a
|
||||
// signature of our static noise key.
|
||||
func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byte, error) {
|
||||
// setup libp2p keys
|
||||
// obtain the public key from the handshake session so we can sign it with
|
||||
// our libp2p secret key.
|
||||
localKeyRaw, err := s.LocalPublicKey().Bytes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error serializing libp2p identity key: %s", err)
|
||||
return nil, fmt.Errorf("error serializing libp2p identity key: %w", err)
|
||||
}
|
||||
|
||||
// sign noise data for payload
|
||||
// prepare payload to sign; perform signature.
|
||||
toSign := append([]byte(payloadSigPrefix), localStatic.Public...)
|
||||
signedPayload, err := s.localKey.Sign(toSign)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error sigining handshake payload: %s", err)
|
||||
return nil, fmt.Errorf("error sigining handshake payload: %w", err)
|
||||
}
|
||||
|
||||
// create payload
|
||||
|
@ -171,7 +191,7 @@ func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byt
|
|||
payload.IdentitySig = signedPayload
|
||||
payloadEnc, err := proto.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling handshake payload: %s", err)
|
||||
return nil, fmt.Errorf("error marshaling handshake payload: %w", err)
|
||||
}
|
||||
return payloadEnc, nil
|
||||
}
|
||||
|
@ -183,7 +203,7 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStati
|
|||
nhp := new(pb.NoiseHandshakePayload)
|
||||
err := proto.Unmarshal(payload, nhp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error unmarshaling remote handshake payload: %s", err)
|
||||
return fmt.Errorf("error unmarshaling remote handshake payload: %w", err)
|
||||
}
|
||||
|
||||
// unpack remote peer's public libp2p key
|
||||
|
@ -198,15 +218,16 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStati
|
|||
|
||||
// if we know who we're trying to reach, make sure we have the right peer
|
||||
if s.initiator && s.remoteID != id {
|
||||
return fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID, id)
|
||||
// use Pretty() as it produces the full b58-encoded string, rather than abbreviated forms.
|
||||
return fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty())
|
||||
}
|
||||
|
||||
// verify payload is signed by libp2p key
|
||||
// verify payload is signed by asserted remote libp2p key.
|
||||
sig := nhp.GetIdentitySig()
|
||||
msg := append([]byte(payloadSigPrefix), remoteStatic...)
|
||||
ok, err := remotePubKey.Verify(msg, sig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error verifying signature: %s", err)
|
||||
return fmt.Errorf("error verifying signature: %w", err)
|
||||
} else if !ok {
|
||||
return fmt.Errorf("handshake signature invalid")
|
||||
}
|
||||
|
|
|
@ -4,17 +4,18 @@ import (
|
|||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
mrand "math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p"
|
||||
"github.com/libp2p/go-libp2p-core/crypto"
|
||||
"github.com/libp2p/go-libp2p-core/host"
|
||||
net "github.com/libp2p/go-libp2p-core/network"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
mrand "math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const testProtocolID = "/test/noise/integration"
|
||||
|
@ -65,22 +66,38 @@ func makeNode(t *testing.T, seed int64, port int) (host.Host, error) {
|
|||
func TestLibp2pIntegration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ha, err := makeNode(t, 1, 33333)
|
||||
ha, err := makeNode(t, 1, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer ha.Close()
|
||||
|
||||
hb, err := makeNode(t, 2, 34343)
|
||||
hb, err := makeNode(t, 2, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer hb.Close()
|
||||
|
||||
ha.SetStreamHandler(testProtocolID, streamHandler(t))
|
||||
hb.SetStreamHandler(testProtocolID, streamHandler(t))
|
||||
doneCh := make(chan struct{})
|
||||
|
||||
// hb reads.
|
||||
hb.SetStreamHandler(testProtocolID, func(stream net.Stream) {
|
||||
defer func() {
|
||||
if err := stream.Close(); err != nil {
|
||||
t.Error("error closing stream: ", err)
|
||||
}
|
||||
close(doneCh)
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
c, err := io.Copy(ioutil.Discard, stream)
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
t.Error("error reading from stream: ", err)
|
||||
return
|
||||
}
|
||||
t.Logf("read %d bytes in %dms", c, elapsed.Milliseconds())
|
||||
})
|
||||
|
||||
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID()))
|
||||
if err != nil {
|
||||
|
@ -99,6 +116,7 @@ func TestLibp2pIntegration(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// ha writes.
|
||||
stream, err := ha.NewStream(ctx, hb.ID(), testProtocolID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -109,9 +127,8 @@ func TestLibp2pIntegration(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
<-doneCh
|
||||
fmt.Println("fin")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
func writeRandomPayloadAndClose(t *testing.T, stream net.Stream) error {
|
||||
|
@ -123,29 +140,9 @@ func writeRandomPayloadAndClose(t *testing.T, stream net.Stream) error {
|
|||
|
||||
c, err := io.Copy(stream, lr)
|
||||
elapsed := time.Since(start)
|
||||
t.Logf("wrote %d bytes in %dms", c, elapsed.Milliseconds())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write out bytes: %v", err)
|
||||
}
|
||||
t.Logf("wrote %d bytes in %dms", c, elapsed.Milliseconds())
|
||||
return stream.Close()
|
||||
}
|
||||
|
||||
func streamHandler(t *testing.T) func(net.Stream) {
|
||||
return func(stream net.Stream) {
|
||||
t.Helper()
|
||||
defer func() {
|
||||
if err := stream.Close(); err != nil {
|
||||
t.Error("error closing stream: ", err)
|
||||
}
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
c, err := io.Copy(ioutil.Discard, stream)
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
t.Error("error reading from stream: ", err)
|
||||
return
|
||||
}
|
||||
t.Logf("read %d bytes in %dms", c, elapsed.Milliseconds())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,95 +4,112 @@ import (
|
|||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
pool "github.com/libp2p/go-buffer-pool"
|
||||
"golang.org/x/crypto/poly1305"
|
||||
)
|
||||
|
||||
// Each encrypted transport message must be <= 65,535 bytes, including 16
|
||||
// bytes of authentication data. To write larger plaintexts, we split them
|
||||
// into fragments of maxPlaintextLength before encrypting.
|
||||
const maxPlaintextLength = 65519
|
||||
// MaxTransportMsgLength is the Noise-imposed maximum transport message length,
|
||||
// inclusive of the MAC size (16 bytes, Poly1305 for noise-libp2p).
|
||||
const MaxTransportMsgLength = 0xffff
|
||||
|
||||
// Read reads from the secure connection, filling `buf` with plaintext data.
|
||||
// May read less than len(buf) if data is available immediately.
|
||||
// MaxPlaintextLength is the maximum payload size. It is MaxTransportMsgLength
|
||||
// minus the MAC size. Payloads over this size will be automatically chunked.
|
||||
const MaxPlaintextLength = MaxTransportMsgLength - poly1305.TagSize
|
||||
|
||||
// LengthPrefixLength is the length of the length prefix itself, which precedes
|
||||
// all transport messages in order to delimit them. In bytes.
|
||||
const LengthPrefixLength = 2
|
||||
|
||||
// Read reads from the secure connection, returning plaintext data in `buf`.
|
||||
//
|
||||
// Honours io.Reader in terms of behaviour.
|
||||
func (s *secureSession) Read(buf []byte) (int, error) {
|
||||
s.readLock.Lock()
|
||||
defer s.readLock.Unlock()
|
||||
|
||||
l := len(buf)
|
||||
|
||||
// if we have previously unread bytes, and they fit into the buf, copy them over and return
|
||||
if l <= len(s.msgBuffer) {
|
||||
copy(buf, s.msgBuffer)
|
||||
s.msgBuffer = s.msgBuffer[l:]
|
||||
return l, nil
|
||||
// 1. If we have queued received bytes:
|
||||
// 1a. If len(buf) < len(queued), saturate buf, update seek pointer, return.
|
||||
// 1b. If len(buf) >= len(queued), copy remaining to buf, release queued buffer back into pool, return.
|
||||
//
|
||||
// 2. Else, read the next message off the wire; next_len is length prefix.
|
||||
// 2a. If len(buf) >= next_len, copy the message to input buffer (zero-alloc path), and return.
|
||||
// 2b. If len(buf) < next_len, obtain buffer from pool, copy entire message into it, saturate buf, update seek pointer.
|
||||
if s.qbuf != nil {
|
||||
// we have queued bytes; copy as much as we can.
|
||||
copied := copy(buf, s.qbuf[s.qseek:])
|
||||
s.qseek += copied
|
||||
if s.qseek == len(s.qbuf) {
|
||||
// queued buffer is now empty, reset and release.
|
||||
pool.Put(s.qbuf)
|
||||
s.qseek, s.qbuf = 0, nil
|
||||
}
|
||||
return copied, nil
|
||||
}
|
||||
|
||||
readChunk := func(buf []byte) (int, error) {
|
||||
ciphertext, err := s.readMsgInsecure()
|
||||
// cbuf is the ciphertext buffer.
|
||||
cbuf, err := s.readMsgInsecure()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
plaintext, err := s.decrypt(ciphertext)
|
||||
if err != nil {
|
||||
// plen is the payload length: the transport message size minus the authentication tag.
|
||||
// if the reader is willing to read at least as many bytes as we are receiving,
|
||||
// decrypt the message directly into the buffer (zero-alloc path).
|
||||
if plen := len(cbuf) - poly1305.TagSize; len(buf) >= plen {
|
||||
defer pool.Put(cbuf)
|
||||
if _, err := s.decrypt(buf[:0], cbuf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return plen, nil
|
||||
}
|
||||
|
||||
// otherwise, get a buffer from the pool so we can stash the payload.
|
||||
// we decrypt in place, since we're retaining cbuf (or a vew thereof).
|
||||
if s.qbuf, err = s.decrypt(cbuf[:0], cbuf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// append plaintext to message buffer, copy over what can fit in the buf
|
||||
// then advance message buffer to remove what was copied
|
||||
s.msgBuffer = append(s.msgBuffer, plaintext...)
|
||||
c := copy(buf, s.msgBuffer)
|
||||
s.msgBuffer = s.msgBuffer[c:]
|
||||
return c, nil
|
||||
}
|
||||
|
||||
total := 0
|
||||
for i := 0; i < len(buf); i += maxPlaintextLength {
|
||||
end := i + maxPlaintextLength
|
||||
if end > len(buf) {
|
||||
end = len(buf)
|
||||
}
|
||||
|
||||
c, err := readChunk(buf[i:end])
|
||||
total += c
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
|
||||
return total, nil
|
||||
// copy as many bytes as we can; update seek pointer.
|
||||
s.qseek = copy(buf, s.qbuf)
|
||||
return s.qseek, nil
|
||||
}
|
||||
|
||||
// Write encrypts the plaintext `in` data and sends it on the
|
||||
// secure connection.
|
||||
func (s *secureSession) Write(in []byte) (int, error) {
|
||||
func (s *secureSession) Write(data []byte) (int, error) {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
|
||||
writeChunk := func(in []byte) (int, error) {
|
||||
ciphertext, err := s.encrypt(in)
|
||||
var (
|
||||
written int
|
||||
cbuf []byte
|
||||
total = len(data)
|
||||
)
|
||||
|
||||
if total < MaxPlaintextLength {
|
||||
cbuf = pool.Get(total + poly1305.TagSize)
|
||||
} else {
|
||||
cbuf = pool.Get(MaxTransportMsgLength)
|
||||
}
|
||||
defer pool.Put(cbuf)
|
||||
|
||||
for written < total {
|
||||
end := written + MaxPlaintextLength
|
||||
if end > total {
|
||||
end = total
|
||||
}
|
||||
|
||||
b, err := s.encrypt(cbuf[:0], data[written:end])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = s.writeMsgInsecure(ciphertext)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(in), err
|
||||
}
|
||||
|
||||
written := 0
|
||||
for i := 0; i < len(in); i += maxPlaintextLength {
|
||||
end := i + maxPlaintextLength
|
||||
if end > len(in) {
|
||||
end = len(in)
|
||||
}
|
||||
|
||||
l, err := writeChunk(in[i:end])
|
||||
written += l
|
||||
_, err = s.writeMsgInsecure(b)
|
||||
if err != nil {
|
||||
return written, err
|
||||
}
|
||||
written = end
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
@ -101,26 +118,31 @@ func (s *secureSession) Write(in []byte) (int, error) {
|
|||
// it first reads the message length, then consumes that many bytes
|
||||
// from the insecure conn.
|
||||
func (s *secureSession) readMsgInsecure() ([]byte, error) {
|
||||
buf := make([]byte, 2)
|
||||
_, err := io.ReadFull(s.insecure, buf)
|
||||
_, err := io.ReadFull(s.insecure, s.rlen[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
size := int(binary.BigEndian.Uint16(buf))
|
||||
buf = make([]byte, size)
|
||||
size := int(binary.BigEndian.Uint16(s.rlen[:]))
|
||||
buf := pool.Get(size)
|
||||
_, err = io.ReadFull(s.insecure, buf)
|
||||
return buf, err
|
||||
}
|
||||
|
||||
// writeMsgInsecure writes to the insecure conn.
|
||||
// data will be prefixed with its length in bytes, written as a 16-bit uint in network order.
|
||||
func (s *secureSession) writeMsgInsecure(data []byte) error {
|
||||
buf := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(buf, uint16(len(data)))
|
||||
_, err := s.insecure.Write(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing length prefix: %s", err)
|
||||
func (s *secureSession) writeMsgInsecure(data []byte) (int, error) {
|
||||
// we rather stage the length-prefixed write in a buffer to then call Write
|
||||
// on the underlying transport at once, rather than Write twice and likely
|
||||
// induce transport-level fragmentation.
|
||||
l := len(data)
|
||||
buf := pool.Get(LengthPrefixLength + l)
|
||||
defer pool.Put(buf)
|
||||
|
||||
// length-prefix || data
|
||||
binary.BigEndian.PutUint16(buf, uint16(l))
|
||||
n := copy(buf[LengthPrefixLength:], data)
|
||||
if n != l {
|
||||
return 0, fmt.Errorf("assertion failed during noise secure channel write; expected to copy %d bytes, copied: %d", l, n)
|
||||
}
|
||||
_, err = s.insecure.Write(data)
|
||||
return err
|
||||
return s.insecure.Write(buf)
|
||||
}
|
||||
|
|
|
@ -20,17 +20,20 @@ type secureSession struct {
|
|||
remoteID peer.ID
|
||||
remoteKey crypto.PubKey
|
||||
|
||||
insecure net.Conn
|
||||
msgBuffer []byte
|
||||
readLock sync.Mutex
|
||||
writeLock sync.Mutex
|
||||
insecure net.Conn
|
||||
|
||||
qseek int // queued bytes seek value.
|
||||
qbuf []byte // queued bytes buffer.
|
||||
rlen [2]byte // work buffer to read in the incoming message length.
|
||||
|
||||
enc *noise.CipherState
|
||||
dec *noise.CipherState
|
||||
}
|
||||
|
||||
// newSecureSession creates a noise session over the given insecure Conn, using the
|
||||
// libp2p identity keypair from the given Transport.
|
||||
// newSecureSession creates a Noise session over the given insecure Conn, using
|
||||
// the libp2p identity keypair from the given Transport.
|
||||
func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, initiator bool) (*secureSession, error) {
|
||||
s := &secureSession{
|
||||
insecure: insecure,
|
||||
|
|
|
@ -35,12 +35,12 @@ func New(privkey crypto.PrivKey) (*Transport, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// SecureInbound runs noise handshake as the responder
|
||||
// SecureInbound runs the Noise handshake as the responder.
|
||||
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.SecureConn, error) {
|
||||
return newSecureSession(t, ctx, insecure, "", false)
|
||||
}
|
||||
|
||||
// SecureOutbound runs noise handshake as the initiator
|
||||
// SecureOutbound runs the Noise handshake as the initiator.
|
||||
func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
|
||||
return newSecureSession(t, ctx, insecure, p, true)
|
||||
}
|
||||
|
|
|
@ -3,9 +3,12 @@ package noise
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
crypto "github.com/libp2p/go-libp2p-core/crypto"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
|
@ -86,6 +89,28 @@ func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSess
|
|||
return initConn.(*secureSession), respConn.(*secureSession)
|
||||
}
|
||||
|
||||
func TestDeadlines(t *testing.T) {
|
||||
initTransport := newTestTransport(t, crypto.Ed25519, 2048)
|
||||
respTransport := newTestTransport(t, crypto.Ed25519, 2048)
|
||||
|
||||
init, resp := newConnPair(t)
|
||||
defer init.Close()
|
||||
defer resp.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := initTransport.SecureOutbound(ctx, init, respTransport.localID)
|
||||
if err == nil {
|
||||
t.Fatalf("expected i/o timeout err; got: %s", err)
|
||||
}
|
||||
|
||||
var neterr net.Error
|
||||
if ok := errors.As(err, &neterr); !ok || !neterr.Timeout() {
|
||||
t.Fatalf("expected i/o timeout err; got: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIDs(t *testing.T) {
|
||||
initTransport := newTestTransport(t, crypto.Ed25519, 2048)
|
||||
respTransport := newTestTransport(t, crypto.Ed25519, 2048)
|
||||
|
@ -177,7 +202,7 @@ func TestLargePayloads(t *testing.T) {
|
|||
}
|
||||
|
||||
after := make([]byte, len(before))
|
||||
afterLen, err := respConn.Read(after)
|
||||
afterLen, err := io.ReadFull(respConn, after)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue