Merge pull request #58 from libp2p/raul-review

This commit is contained in:
Raúl Kripalani 2020-04-24 20:54:23 +01:00 committed by GitHub
commit 9b9c06f042
10 changed files with 252 additions and 153 deletions

View File

@ -124,6 +124,10 @@ func pipeRandom(src rand.Source, w io.WriteCloser, r io.Reader, n int64) error {
func benchDataTransfer(b *benchenv, size int64) {
var totalBytes int64
var totalTime time.Duration
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
initSession, respSession := b.connect(true)
@ -153,6 +157,9 @@ func BenchmarkTransfer500Mb(b *testing.B) {
}
func (b benchenv) benchHandshake() {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
i, r := b.connect(false)
b.StopTimer()

View File

@ -1,22 +1,46 @@
package noise
import "errors"
import (
"errors"
)
func (s *secureSession) encrypt(plaintext []byte) (ciphertext []byte, err error) {
// encrypt calls the cipher's encryption. It encrypts the provided plaintext,
// slice-appending the ciphertext on out.
//
// Usually you want to pass a 0-len slice to this method, with enough capacity
// to accommodate the ciphertext in order to spare allocs.
//
// encrypt returns a new slice header, whose len is the length of the resulting
// ciphertext, including the authentication tag.
//
// This method will not allocate if the supplied slice is large enough to
// accommodate the encrypted data + authentication tag. If so, the returned
// slice header should be a view of the original slice.
//
// With the poly1305 MAC function that noise-libp2p uses, the authentication tag
// adds an overhead of 16 bytes.
func (s *secureSession) encrypt(out, plaintext []byte) ([]byte, error) {
if s.enc == nil {
return nil, errors.New("cannot encrypt, handshake incomplete")
}
// TODO: use pre-allocated buffers
ciphertext = s.enc.Encrypt(nil, nil, plaintext)
return ciphertext, nil
return s.enc.Encrypt(out, nil, plaintext), nil
}
func (s *secureSession) decrypt(ciphertext []byte) (plaintext []byte, err error) {
// decrypt calls the cipher's decryption. It decrypts the provided ciphertext,
// slice-appending the plaintext on out.
//
// Usually you want to pass a 0-len slice to this method, with enough capacity
// to accommodate the plaintext in order to spare allocs.
//
// decrypt returns a new slice header, whose len is the length of the resulting
// plaintext, without the authentication tag.
//
// This method will not allocate if the supplied slice is large enough to
// accommodate the plaintext. If so, the returned slice header should be a view
// of the original slice.
func (s *secureSession) decrypt(out, ciphertext []byte) ([]byte, error) {
if s.dec == nil {
return nil, errors.New("cannot decrypt, handshake incomplete")
}
// TODO: use pre-allocated buffers
return s.dec.Decrypt(nil, nil, ciphertext)
return s.dec.Decrypt(out, nil, ciphertext)
}

View File

@ -18,12 +18,12 @@ func TestEncryptAndDecrypt_InitToResp(t *testing.T) {
defer respConn.Close()
plaintext := []byte("helloworld")
ciphertext, err := initConn.encrypt(plaintext)
ciphertext, err := initConn.encrypt(nil, plaintext)
if err != nil {
t.Fatal(err)
}
result, err := respConn.decrypt(ciphertext)
result, err := respConn.decrypt(nil, ciphertext)
if !bytes.Equal(plaintext, result) {
t.Fatalf("got %x expected %x", result, plaintext)
} else if err != nil {
@ -31,12 +31,12 @@ func TestEncryptAndDecrypt_InitToResp(t *testing.T) {
}
plaintext = []byte("goodbye")
ciphertext, err = initConn.encrypt(plaintext)
ciphertext, err = initConn.encrypt(nil, plaintext)
if err != nil {
t.Fatal(err)
}
result, err = respConn.decrypt(ciphertext)
result, err = respConn.decrypt(nil, ciphertext)
if !bytes.Equal(plaintext, result) {
t.Fatalf("got %x expected %x", result, plaintext)
} else if err != nil {
@ -53,12 +53,12 @@ func TestEncryptAndDecrypt_RespToInit(t *testing.T) {
defer respConn.Close()
plaintext := []byte("helloworld")
ciphertext, err := respConn.encrypt(plaintext)
ciphertext, err := respConn.encrypt(nil, plaintext)
if err != nil {
t.Fatal(err)
}
result, err := initConn.decrypt(ciphertext)
result, err := initConn.decrypt(nil, ciphertext)
if !bytes.Equal(plaintext, result) {
t.Fatalf("got %x expected %x", result, plaintext)
} else if err != nil {
@ -75,14 +75,14 @@ func TestCryptoFailsIfCiphertextIsAltered(t *testing.T) {
defer respConn.Close()
plaintext := []byte("helloworld")
ciphertext, err := respConn.encrypt(plaintext)
ciphertext, err := respConn.encrypt(nil, plaintext)
if err != nil {
t.Fatal(err)
}
ciphertext[0] = ^ciphertext[0]
_, err = initConn.decrypt(ciphertext)
_, err = initConn.decrypt(nil, ciphertext)
if err == nil {
t.Fatal("expected decryption to fail when ciphertext altered")
}
@ -94,11 +94,11 @@ func TestCryptoFailsIfHandshakeIncomplete(t *testing.T) {
_ = resp.Close()
session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", true)
_, err := session.encrypt([]byte("hi"))
_, err := session.encrypt(nil, []byte("hi"))
if err == nil {
t.Error("expected encryption error when handshake incomplete")
}
_, err = session.decrypt([]byte("it's a secret"))
_, err = session.decrypt(nil, []byte("it's a secret"))
if err == nil {
t.Error("expected decryption error when handshake incomplete")
}

View File

@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"fmt"
"time"
"github.com/flynn/noise"
"github.com/gogo/protobuf/proto"
@ -26,7 +27,7 @@ var cipherSuite = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, no
func (s *secureSession) runHandshake(ctx context.Context) error {
kp, err := noise.DH25519.GenerateKeypair(rand.Reader)
if err != nil {
return fmt.Errorf("error generating static keypair: %s", err)
return fmt.Errorf("error generating static keypair: %w", err)
}
cfg := noise.Config{
@ -38,7 +39,7 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
hs, err := noise.NewHandshakeState(cfg)
if err != nil {
return fmt.Errorf("error initializing handshake state: %s", err)
return fmt.Errorf("error initializing handshake state: %w", err)
}
payload, err := s.generateHandshakePayload(kp)
@ -46,17 +47,30 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
return err
}
// set a deadline to complete the handshake, if one has been supplied.
// clear it after we're done.
if deadline, ok := ctx.Deadline(); ok {
if err := s.SetDeadline(deadline); err == nil {
// schedule the deadline removal once we're done handshaking.
defer s.SetDeadline(time.Time{})
}
// TODO: else case (transport doesn't support native timeouts); spin off
// a goroutine to monitor the context cancellation and pull the rug
// from under by closing the connection altogether.
}
if s.initiator {
// stage 0 //
// do not send the payload just yet, as it would be plaintext; not secret.
err = s.sendHandshakeMessage(hs, nil)
if err != nil {
return fmt.Errorf("error sending handshake message: %s", err)
return fmt.Errorf("error sending handshake message: %w", err)
}
// stage 1 //
plaintext, err := s.readHandshakeMessage(hs)
if err != nil {
return fmt.Errorf("error reading handshake message: %s", err)
return fmt.Errorf("error reading handshake message: %w", err)
}
err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
if err != nil {
@ -66,25 +80,25 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
// stage 2 //
err = s.sendHandshakeMessage(hs, payload)
if err != nil {
return fmt.Errorf("error sending handshake message: %s", err)
return fmt.Errorf("error sending handshake message: %w", err)
}
} else {
// stage 0 //
plaintext, err := s.readHandshakeMessage(hs)
if err != nil {
return fmt.Errorf("error reading handshake message: %s", err)
return fmt.Errorf("error reading handshake message: %w", err)
}
// stage 1 //
err = s.sendHandshakeMessage(hs, payload)
if err != nil {
return fmt.Errorf("error sending handshake message: %s", err)
return fmt.Errorf("error sending handshake message: %w", err)
}
// stage 2 //
plaintext, err = s.readHandshakeMessage(hs)
if err != nil {
return fmt.Errorf("error reading handshake message: %s", err)
return fmt.Errorf("error reading handshake message: %w", err)
}
err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
if err != nil {
@ -95,9 +109,11 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
return nil
}
// setCipherStates is called when the final handshake message is processed by
// setCipherStates sets the initial cipher states that will be used to protect
// traffic after the handshake.
//
// It is called when the final handshake message is processed by
// either sendHandshakeMessage or readHandshakeMessage.
// It sets the initial cipher states that will be used to protect traffic after the handshake.
func (s *secureSession) setCipherStates(cs1, cs2 *noise.CipherState) {
if s.initiator {
s.enc = cs1
@ -109,6 +125,7 @@ func (s *secureSession) setCipherStates(cs1, cs2 *noise.CipherState) {
}
// sendHandshakeMessage sends the next handshake message in the sequence.
//
// If payload is non-empty, it will be included in the handshake message.
// If this is the final message in the sequence, calls setCipherStates
// to initialize cipher states.
@ -118,7 +135,7 @@ func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload [
return err
}
err = s.writeMsgInsecure(buf)
_, err = s.writeMsgInsecure(buf)
if err != nil {
return err
}
@ -131,8 +148,10 @@ func (s *secureSession) sendHandshakeMessage(hs *noise.HandshakeState, payload [
// readHandshakeMessage reads a message from the insecure conn and tries to
// process it as the expected next message in the handshake sequence.
//
// If the message contains a payload, it will be decrypted and returned.
// If this is the final message in the sequence, calls setCipherStates
//
// If this is the final message in the sequence, it calls setCipherStates
// to initialize cipher states.
func (s *secureSession) readHandshakeMessage(hs *noise.HandshakeState) ([]byte, error) {
raw, err := s.readMsgInsecure()
@ -152,17 +171,18 @@ func (s *secureSession) readHandshakeMessage(hs *noise.HandshakeState) ([]byte,
// generateHandshakePayload creates a libp2p handshake payload with a
// signature of our static noise key.
func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byte, error) {
// setup libp2p keys
// obtain the public key from the handshake session so we can sign it with
// our libp2p secret key.
localKeyRaw, err := s.LocalPublicKey().Bytes()
if err != nil {
return nil, fmt.Errorf("error serializing libp2p identity key: %s", err)
return nil, fmt.Errorf("error serializing libp2p identity key: %w", err)
}
// sign noise data for payload
// prepare payload to sign; perform signature.
toSign := append([]byte(payloadSigPrefix), localStatic.Public...)
signedPayload, err := s.localKey.Sign(toSign)
if err != nil {
return nil, fmt.Errorf("error sigining handshake payload: %s", err)
return nil, fmt.Errorf("error sigining handshake payload: %w", err)
}
// create payload
@ -171,7 +191,7 @@ func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byt
payload.IdentitySig = signedPayload
payloadEnc, err := proto.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("error marshaling handshake payload: %s", err)
return nil, fmt.Errorf("error marshaling handshake payload: %w", err)
}
return payloadEnc, nil
}
@ -183,7 +203,7 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStati
nhp := new(pb.NoiseHandshakePayload)
err := proto.Unmarshal(payload, nhp)
if err != nil {
return fmt.Errorf("error unmarshaling remote handshake payload: %s", err)
return fmt.Errorf("error unmarshaling remote handshake payload: %w", err)
}
// unpack remote peer's public libp2p key
@ -198,15 +218,16 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStati
// if we know who we're trying to reach, make sure we have the right peer
if s.initiator && s.remoteID != id {
return fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID, id)
// use Pretty() as it produces the full b58-encoded string, rather than abbreviated forms.
return fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty())
}
// verify payload is signed by libp2p key
// verify payload is signed by asserted remote libp2p key.
sig := nhp.GetIdentitySig()
msg := append([]byte(payloadSigPrefix), remoteStatic...)
ok, err := remotePubKey.Verify(msg, sig)
if err != nil {
return fmt.Errorf("error verifying signature: %s", err)
return fmt.Errorf("error verifying signature: %w", err)
} else if !ok {
return fmt.Errorf("handshake signature invalid")
}

View File

@ -4,17 +4,18 @@ import (
"context"
"crypto/rand"
"fmt"
"io"
"io/ioutil"
mrand "math/rand"
"testing"
"time"
"github.com/libp2p/go-libp2p"
"github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/host"
net "github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
ma "github.com/multiformats/go-multiaddr"
"io"
"io/ioutil"
mrand "math/rand"
"testing"
"time"
)
const testProtocolID = "/test/noise/integration"
@ -65,22 +66,38 @@ func makeNode(t *testing.T, seed int64, port int) (host.Host, error) {
func TestLibp2pIntegration(t *testing.T) {
ctx := context.Background()
ha, err := makeNode(t, 1, 33333)
ha, err := makeNode(t, 1, 0)
if err != nil {
t.Fatal(err)
}
defer ha.Close()
hb, err := makeNode(t, 2, 34343)
hb, err := makeNode(t, 2, 0)
if err != nil {
t.Fatal(err)
}
defer hb.Close()
ha.SetStreamHandler(testProtocolID, streamHandler(t))
hb.SetStreamHandler(testProtocolID, streamHandler(t))
doneCh := make(chan struct{})
// hb reads.
hb.SetStreamHandler(testProtocolID, func(stream net.Stream) {
defer func() {
if err := stream.Close(); err != nil {
t.Error("error closing stream: ", err)
}
close(doneCh)
}()
start := time.Now()
c, err := io.Copy(ioutil.Discard, stream)
elapsed := time.Since(start)
if err != nil {
t.Error("error reading from stream: ", err)
return
}
t.Logf("read %d bytes in %dms", c, elapsed.Milliseconds())
})
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID()))
if err != nil {
@ -99,6 +116,7 @@ func TestLibp2pIntegration(t *testing.T) {
t.Fatal(err)
}
// ha writes.
stream, err := ha.NewStream(ctx, hb.ID(), testProtocolID)
if err != nil {
t.Fatal(err)
@ -109,9 +127,8 @@ func TestLibp2pIntegration(t *testing.T) {
t.Fatal(err)
}
<-doneCh
fmt.Println("fin")
time.Sleep(time.Second)
}
func writeRandomPayloadAndClose(t *testing.T, stream net.Stream) error {
@ -123,29 +140,9 @@ func writeRandomPayloadAndClose(t *testing.T, stream net.Stream) error {
c, err := io.Copy(stream, lr)
elapsed := time.Since(start)
t.Logf("wrote %d bytes in %dms", c, elapsed.Milliseconds())
if err != nil {
return fmt.Errorf("failed to write out bytes: %v", err)
}
t.Logf("wrote %d bytes in %dms", c, elapsed.Milliseconds())
return stream.Close()
}
func streamHandler(t *testing.T) func(net.Stream) {
return func(stream net.Stream) {
t.Helper()
defer func() {
if err := stream.Close(); err != nil {
t.Error("error closing stream: ", err)
}
}()
start := time.Now()
c, err := io.Copy(ioutil.Discard, stream)
elapsed := time.Since(start)
if err != nil {
t.Error("error reading from stream: ", err)
return
}
t.Logf("read %d bytes in %dms", c, elapsed.Milliseconds())
}
}

View File

@ -4,95 +4,112 @@ import (
"encoding/binary"
"fmt"
"io"
pool "github.com/libp2p/go-buffer-pool"
"golang.org/x/crypto/poly1305"
)
// Each encrypted transport message must be <= 65,535 bytes, including 16
// bytes of authentication data. To write larger plaintexts, we split them
// into fragments of maxPlaintextLength before encrypting.
const maxPlaintextLength = 65519
// MaxTransportMsgLength is the Noise-imposed maximum transport message length,
// inclusive of the MAC size (16 bytes, Poly1305 for noise-libp2p).
const MaxTransportMsgLength = 0xffff
// Read reads from the secure connection, filling `buf` with plaintext data.
// May read less than len(buf) if data is available immediately.
// MaxPlaintextLength is the maximum payload size. It is MaxTransportMsgLength
// minus the MAC size. Payloads over this size will be automatically chunked.
const MaxPlaintextLength = MaxTransportMsgLength - poly1305.TagSize
// LengthPrefixLength is the length of the length prefix itself, which precedes
// all transport messages in order to delimit them. In bytes.
const LengthPrefixLength = 2
// Read reads from the secure connection, returning plaintext data in `buf`.
//
// Honours io.Reader in terms of behaviour.
func (s *secureSession) Read(buf []byte) (int, error) {
s.readLock.Lock()
defer s.readLock.Unlock()
l := len(buf)
// if we have previously unread bytes, and they fit into the buf, copy them over and return
if l <= len(s.msgBuffer) {
copy(buf, s.msgBuffer)
s.msgBuffer = s.msgBuffer[l:]
return l, nil
// 1. If we have queued received bytes:
// 1a. If len(buf) < len(queued), saturate buf, update seek pointer, return.
// 1b. If len(buf) >= len(queued), copy remaining to buf, release queued buffer back into pool, return.
//
// 2. Else, read the next message off the wire; next_len is length prefix.
// 2a. If len(buf) >= next_len, copy the message to input buffer (zero-alloc path), and return.
// 2b. If len(buf) < next_len, obtain buffer from pool, copy entire message into it, saturate buf, update seek pointer.
if s.qbuf != nil {
// we have queued bytes; copy as much as we can.
copied := copy(buf, s.qbuf[s.qseek:])
s.qseek += copied
if s.qseek == len(s.qbuf) {
// queued buffer is now empty, reset and release.
pool.Put(s.qbuf)
s.qseek, s.qbuf = 0, nil
}
return copied, nil
}
readChunk := func(buf []byte) (int, error) {
ciphertext, err := s.readMsgInsecure()
if err != nil {
// cbuf is the ciphertext buffer.
cbuf, err := s.readMsgInsecure()
if err != nil {
return 0, err
}
// plen is the payload length: the transport message size minus the authentication tag.
// if the reader is willing to read at least as many bytes as we are receiving,
// decrypt the message directly into the buffer (zero-alloc path).
if plen := len(cbuf) - poly1305.TagSize; len(buf) >= plen {
defer pool.Put(cbuf)
if _, err := s.decrypt(buf[:0], cbuf); err != nil {
return 0, err
}
plaintext, err := s.decrypt(ciphertext)
if err != nil {
return 0, err
}
// append plaintext to message buffer, copy over what can fit in the buf
// then advance message buffer to remove what was copied
s.msgBuffer = append(s.msgBuffer, plaintext...)
c := copy(buf, s.msgBuffer)
s.msgBuffer = s.msgBuffer[c:]
return c, nil
return plen, nil
}
total := 0
for i := 0; i < len(buf); i += maxPlaintextLength {
end := i + maxPlaintextLength
if end > len(buf) {
end = len(buf)
}
c, err := readChunk(buf[i:end])
total += c
if err != nil {
return total, err
}
// otherwise, get a buffer from the pool so we can stash the payload.
// we decrypt in place, since we're retaining cbuf (or a vew thereof).
if s.qbuf, err = s.decrypt(cbuf[:0], cbuf); err != nil {
return 0, err
}
return total, nil
// copy as many bytes as we can; update seek pointer.
s.qseek = copy(buf, s.qbuf)
return s.qseek, nil
}
// Write encrypts the plaintext `in` data and sends it on the
// secure connection.
func (s *secureSession) Write(in []byte) (int, error) {
func (s *secureSession) Write(data []byte) (int, error) {
s.writeLock.Lock()
defer s.writeLock.Unlock()
writeChunk := func(in []byte) (int, error) {
ciphertext, err := s.encrypt(in)
if err != nil {
return 0, err
}
var (
written int
cbuf []byte
total = len(data)
)
err = s.writeMsgInsecure(ciphertext)
if err != nil {
return 0, err
}
return len(in), err
if total < MaxPlaintextLength {
cbuf = pool.Get(total + poly1305.TagSize)
} else {
cbuf = pool.Get(MaxTransportMsgLength)
}
defer pool.Put(cbuf)
written := 0
for i := 0; i < len(in); i += maxPlaintextLength {
end := i + maxPlaintextLength
if end > len(in) {
end = len(in)
for written < total {
end := written + MaxPlaintextLength
if end > total {
end = total
}
l, err := writeChunk(in[i:end])
written += l
b, err := s.encrypt(cbuf[:0], data[written:end])
if err != nil {
return 0, err
}
_, err = s.writeMsgInsecure(b)
if err != nil {
return written, err
}
written = end
}
return written, nil
}
@ -101,26 +118,31 @@ func (s *secureSession) Write(in []byte) (int, error) {
// it first reads the message length, then consumes that many bytes
// from the insecure conn.
func (s *secureSession) readMsgInsecure() ([]byte, error) {
buf := make([]byte, 2)
_, err := io.ReadFull(s.insecure, buf)
_, err := io.ReadFull(s.insecure, s.rlen[:])
if err != nil {
return nil, err
}
size := int(binary.BigEndian.Uint16(buf))
buf = make([]byte, size)
size := int(binary.BigEndian.Uint16(s.rlen[:]))
buf := pool.Get(size)
_, err = io.ReadFull(s.insecure, buf)
return buf, err
}
// writeMsgInsecure writes to the insecure conn.
// data will be prefixed with its length in bytes, written as a 16-bit uint in network order.
func (s *secureSession) writeMsgInsecure(data []byte) error {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(len(data)))
_, err := s.insecure.Write(buf)
if err != nil {
return fmt.Errorf("error writing length prefix: %s", err)
func (s *secureSession) writeMsgInsecure(data []byte) (int, error) {
// we rather stage the length-prefixed write in a buffer to then call Write
// on the underlying transport at once, rather than Write twice and likely
// induce transport-level fragmentation.
l := len(data)
buf := pool.Get(LengthPrefixLength + l)
defer pool.Put(buf)
// length-prefix || data
binary.BigEndian.PutUint16(buf, uint16(l))
n := copy(buf[LengthPrefixLength:], data)
if n != l {
return 0, fmt.Errorf("assertion failed during noise secure channel write; expected to copy %d bytes, copied: %d", l, n)
}
_, err = s.insecure.Write(data)
return err
return s.insecure.Write(buf)
}

View File

@ -20,17 +20,20 @@ type secureSession struct {
remoteID peer.ID
remoteKey crypto.PubKey
insecure net.Conn
msgBuffer []byte
readLock sync.Mutex
writeLock sync.Mutex
insecure net.Conn
qseek int // queued bytes seek value.
qbuf []byte // queued bytes buffer.
rlen [2]byte // work buffer to read in the incoming message length.
enc *noise.CipherState
dec *noise.CipherState
}
// newSecureSession creates a noise session over the given insecure Conn, using the
// libp2p identity keypair from the given Transport.
// newSecureSession creates a Noise session over the given insecure Conn, using
// the libp2p identity keypair from the given Transport.
func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, initiator bool) (*secureSession, error) {
s := &secureSession{
insecure: insecure,

View File

@ -35,12 +35,12 @@ func New(privkey crypto.PrivKey) (*Transport, error) {
}, nil
}
// SecureInbound runs noise handshake as the responder
// SecureInbound runs the Noise handshake as the responder.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.SecureConn, error) {
return newSecureSession(t, ctx, insecure, "", false)
}
// SecureOutbound runs noise handshake as the initiator
// SecureOutbound runs the Noise handshake as the initiator.
func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
return newSecureSession(t, ctx, insecure, p, true)
}

View File

@ -3,9 +3,12 @@ package noise
import (
"bytes"
"context"
"errors"
"io"
"math/rand"
"net"
"testing"
"time"
crypto "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer"
@ -86,6 +89,28 @@ func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSess
return initConn.(*secureSession), respConn.(*secureSession)
}
func TestDeadlines(t *testing.T) {
initTransport := newTestTransport(t, crypto.Ed25519, 2048)
respTransport := newTestTransport(t, crypto.Ed25519, 2048)
init, resp := newConnPair(t)
defer init.Close()
defer resp.Close()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
_, err := initTransport.SecureOutbound(ctx, init, respTransport.localID)
if err == nil {
t.Fatalf("expected i/o timeout err; got: %s", err)
}
var neterr net.Error
if ok := errors.As(err, &neterr); !ok || !neterr.Timeout() {
t.Fatalf("expected i/o timeout err; got: %s", err)
}
}
func TestIDs(t *testing.T) {
initTransport := newTestTransport(t, crypto.Ed25519, 2048)
respTransport := newTestTransport(t, crypto.Ed25519, 2048)
@ -177,7 +202,7 @@ func TestLargePayloads(t *testing.T) {
}
after := make([]byte, len(before))
afterLen, err := respConn.Read(after)
afterLen, err := io.ReadFull(respConn, after)
if err != nil {
t.Fatal(err)
}