From 8c0011194e4efbc9d38b0b2c7f2d3685bc01e43a Mon Sep 17 00:00:00 2001 From: Yusef Napora Date: Tue, 28 Jan 2020 13:53:25 -0500 Subject: [PATCH 1/3] fix protocol id in integration test --- p2p/security/noise/integration_test.go | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/p2p/security/noise/integration_test.go b/p2p/security/noise/integration_test.go index 616ab3b5..279ad9d6 100644 --- a/p2p/security/noise/integration_test.go +++ b/p2p/security/noise/integration_test.go @@ -17,6 +17,8 @@ import ( "time" ) +const testProtocolID = "/test/noise/integration" + func generateKey(seed int64) (crypto.PrivKey, error) { var r io.Reader if seed == 0 { @@ -109,8 +111,8 @@ func TestLibp2pIntegration_NoPipes(t *testing.T) { defer hb.Close() - ha.SetStreamHandler(ID, handleStream) - hb.SetStreamHandler(ID, handleStream) + ha.SetStreamHandler(testProtocolID, handleStream) + hb.SetStreamHandler(testProtocolID, handleStream) addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID())) if err != nil { @@ -129,7 +131,7 @@ func TestLibp2pIntegration_NoPipes(t *testing.T) { t.Fatal(err) } - stream, err := ha.NewStream(ctx, hb.ID(), ID) + stream, err := ha.NewStream(ctx, hb.ID(), testProtocolID) if err != nil { t.Fatal(err) } @@ -166,8 +168,8 @@ func TestLibp2pIntegration_WithPipes(t *testing.T) { defer hb.Close() - ha.SetStreamHandler(ID, handleStream) - hb.SetStreamHandler(ID, handleStream) + ha.SetStreamHandler(testProtocolID, handleStream) + hb.SetStreamHandler(testProtocolID, handleStream) addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", ha.Addrs()[0].String(), ha.ID())) if err != nil { @@ -186,7 +188,7 @@ func TestLibp2pIntegration_WithPipes(t *testing.T) { t.Fatal(err) } - stream, err := hb.NewStream(ctx, ha.ID(), ID) + stream, err := hb.NewStream(ctx, ha.ID(), testProtocolID) if err != nil { t.Fatal(err) } @@ -223,8 +225,8 @@ func TestLibp2pIntegration_XXFallback(t *testing.T) { defer hb.Close() - ha.SetStreamHandler(ID, handleStream) - hb.SetStreamHandler(ID, handleStream) + ha.SetStreamHandler(testProtocolID, handleStream) + hb.SetStreamHandler(testProtocolID, handleStream) addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID())) if err != nil { @@ -243,7 +245,7 @@ func TestLibp2pIntegration_XXFallback(t *testing.T) { t.Fatal(err) } - stream, err := hb.NewStream(ctx, ha.ID(), ID) + stream, err := hb.NewStream(ctx, ha.ID(), testProtocolID) if err != nil { t.Fatal(err) } From 297dd7fae97bbd2cf4c169b96114f5cca192b7a2 Mon Sep 17 00:00:00 2001 From: Yusef Napora Date: Wed, 29 Jan 2020 09:05:19 -0500 Subject: [PATCH 2/3] send much more data in integration test --- p2p/security/noise/integration_test.go | 62 +++++++++++++++++--------- 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/p2p/security/noise/integration_test.go b/p2p/security/noise/integration_test.go index 279ad9d6..0cb2a930 100644 --- a/p2p/security/noise/integration_test.go +++ b/p2p/security/noise/integration_test.go @@ -1,7 +1,6 @@ package noise import ( - "bufio" "context" "crypto/rand" "fmt" @@ -12,6 +11,7 @@ import ( "github.com/libp2p/go-libp2p-core/peer" ma "github.com/multiformats/go-multiaddr" "io" + "io/ioutil" mrand "math/rand" "testing" "time" @@ -111,8 +111,8 @@ func TestLibp2pIntegration_NoPipes(t *testing.T) { defer hb.Close() - ha.SetStreamHandler(testProtocolID, handleStream) - hb.SetStreamHandler(testProtocolID, handleStream) + ha.SetStreamHandler(testProtocolID, streamHandler(t)) + hb.SetStreamHandler(testProtocolID, streamHandler(t)) addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID())) if err != nil { @@ -136,7 +136,7 @@ func TestLibp2pIntegration_NoPipes(t *testing.T) { t.Fatal(err) } - _, err = stream.Write([]byte("hello\n")) + err = writeRandomPayloadAndClose(t, stream) if err != nil { t.Fatal(err) } @@ -168,8 +168,8 @@ func TestLibp2pIntegration_WithPipes(t *testing.T) { defer hb.Close() - ha.SetStreamHandler(testProtocolID, handleStream) - hb.SetStreamHandler(testProtocolID, handleStream) + ha.SetStreamHandler(testProtocolID, streamHandler(t)) + hb.SetStreamHandler(testProtocolID, streamHandler(t)) addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", ha.Addrs()[0].String(), ha.ID())) if err != nil { @@ -193,7 +193,7 @@ func TestLibp2pIntegration_WithPipes(t *testing.T) { t.Fatal(err) } - _, err = stream.Write([]byte("hello\n")) + err = writeRandomPayloadAndClose(t, stream) if err != nil { t.Fatal(err) } @@ -225,8 +225,8 @@ func TestLibp2pIntegration_XXFallback(t *testing.T) { defer hb.Close() - ha.SetStreamHandler(testProtocolID, handleStream) - hb.SetStreamHandler(testProtocolID, handleStream) + ha.SetStreamHandler(testProtocolID, streamHandler(t)) + hb.SetStreamHandler(testProtocolID, streamHandler(t)) addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID())) if err != nil { @@ -250,7 +250,7 @@ func TestLibp2pIntegration_XXFallback(t *testing.T) { t.Fatal(err) } - _, err = stream.Write([]byte("hello\n")) + err = writeRandomPayloadAndClose(t, stream) if err != nil { t.Fatal(err) } @@ -277,18 +277,38 @@ func TestConstrucingWithMaker(t *testing.T) { _ = h.Close() } -func handleStream(stream net.Stream) { - defer func() { - if err := stream.Close(); err != nil { - log.Error("error closing stream", "err", err) - } - }() +func writeRandomPayloadAndClose(t *testing.T, stream net.Stream) error { + t.Helper() + size := 1 << 24 + r := mrand.New(mrand.NewSource(42)) + start := time.Now() + lr := io.LimitReader(r, int64(size)) - rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream)) - msg, err := rw.Reader.ReadString('\n') + c, err := io.Copy(stream, lr) + elapsed := time.Since(start) if err != nil { - fmt.Println("stream err", err) - return + return fmt.Errorf("failed to write out bytes: %v", err) + } + t.Logf("wrote %d bytes in %dms", c, elapsed.Milliseconds()) + return stream.Close() +} + +func streamHandler(t *testing.T) func(net.Stream) { + return func(stream net.Stream) { + t.Helper() + defer func() { + if err := stream.Close(); err != nil { + t.Error("error closing stream: ", err) + } + }() + + start := time.Now() + c, err := io.Copy(ioutil.Discard, stream) + elapsed := time.Since(start) + if err != nil { + t.Error("error reading from stream: ", err) + return + } + t.Logf("read %d bytes in %dms", c, elapsed.Milliseconds()) } - fmt.Printf("got msg: %s", msg) } From 7df3fe54e90988d808dc3ce518d0eda2d035bfd7 Mon Sep 17 00:00:00 2001 From: Yusef Napora Date: Wed, 29 Jan 2020 09:07:06 -0500 Subject: [PATCH 3/3] make sure we fill buffer when reading from conn adds fillBuffer and writeAll helpers to make sure that we're actually filling our input buffers when reading from the insecure conn, and that we're writing the entire output buffer, even if it takes multiple calls to insecure.Read or insecure.Write --- p2p/security/noise/ik_handshake.go | 4 ++-- p2p/security/noise/protocol.go | 28 +++++++++++++++------------ p2p/security/noise/util.go | 31 ++++++++++++++++++++++++++++++ p2p/security/noise/xx_handshake.go | 4 ++-- 4 files changed, 51 insertions(+), 16 deletions(-) create mode 100644 p2p/security/noise/util.go diff --git a/p2p/security/noise/ik_handshake.go b/p2p/security/noise/ik_handshake.go index d672e30a..428c4cd5 100644 --- a/p2p/security/noise/ik_handshake.go +++ b/p2p/security/noise/ik_handshake.go @@ -28,7 +28,7 @@ func (s *secureSession) ik_sendHandshakeMessage(payload []byte, initial_stage bo } // send message - _, err = s.insecure.Write(encMsgBuf) + _, err = writeAll(s.insecure, encMsgBuf) if err != nil { log.Error("ik_sendHandshakeMessage initiator=%v err=%s", s.initiator, err) return fmt.Errorf("ik_sendHandshakeMessage write to conn err=%s", err) @@ -45,7 +45,7 @@ func (s *secureSession) ik_recvHandshakeMessage(initial_stage bool) (buf []byte, buf = make([]byte, l) - _, err = s.insecure.Read(buf) + _, err = fillBuffer(buf, s.insecure) if err != nil { return buf, nil, false, fmt.Errorf("ik_recvHandshakeMessage read from conn err=%s", err) } diff --git a/p2p/security/noise/protocol.go b/p2p/security/noise/protocol.go index 48b4c41e..3d980cf5 100644 --- a/p2p/security/noise/protocol.go +++ b/p2p/security/noise/protocol.go @@ -9,14 +9,14 @@ import ( "sync" "time" - proto "github.com/gogo/protobuf/proto" + "github.com/gogo/protobuf/proto" logging "github.com/ipfs/go-log" "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/peer" - ik "github.com/libp2p/go-libp2p-noise/ik" - pb "github.com/libp2p/go-libp2p-noise/pb" - xx "github.com/libp2p/go-libp2p-noise/xx" + "github.com/libp2p/go-libp2p-noise/ik" + "github.com/libp2p/go-libp2p-noise/pb" + "github.com/libp2p/go-libp2p-noise/xx" ) const payload_string = "noise-libp2p-static-key:" @@ -55,7 +55,8 @@ type secureSession struct { noiseKeypair *Keypair msgBuffer []byte - rwLock sync.Mutex + readLock sync.Mutex + writeLock sync.Mutex } type peerInfo struct { @@ -120,14 +121,14 @@ func (s *secureSession) NoisePrivateKey() [32]byte { func (s *secureSession) readLength() (int, error) { buf := make([]byte, 2) - _, err := s.insecure.Read(buf) + _, err := fillBuffer(buf, s.insecure) return int(binary.BigEndian.Uint16(buf)), err } func (s *secureSession) writeLength(length int) error { buf := make([]byte, 2) binary.BigEndian.PutUint16(buf, uint16(length)) - _, err := s.insecure.Write(buf) + _, err := writeAll(s.insecure, buf) return err } @@ -237,6 +238,9 @@ func (s *secureSession) LocalPublicKey() crypto.PubKey { } func (s *secureSession) Read(buf []byte) (int, error) { + s.readLock.Lock() + defer s.readLock.Unlock() + l := len(buf) // if we have previously unread bytes, and they fit into the buf, copy them over and return @@ -255,7 +259,7 @@ func (s *secureSession) Read(buf []byte) (int, error) { // read and decrypt ciphertext ciphertext := make([]byte, l) - _, err = s.insecure.Read(ciphertext) + _, err = fillBuffer(ciphertext, s.insecure) if err != nil { log.Error("read ciphertext err", err) return 0, err @@ -317,8 +321,8 @@ func (s *secureSession) SetWriteDeadline(t time.Time) error { } func (s *secureSession) Write(in []byte) (int, error) { - s.rwLock.Lock() - defer s.rwLock.Unlock() + s.writeLock.Lock() + defer s.writeLock.Unlock() writeChunk := func(in []byte) (int, error) { ciphertext, err := s.Encrypt(in) @@ -329,11 +333,11 @@ func (s *secureSession) Write(in []byte) (int, error) { err = s.writeLength(len(ciphertext)) if err != nil { - log.Error("write length err", err) + log.Error("write length err: ", err) return 0, err } - _, err = s.insecure.Write(ciphertext) + _, err = writeAll(s.insecure, ciphertext) return len(in), err } diff --git a/p2p/security/noise/util.go b/p2p/security/noise/util.go new file mode 100644 index 00000000..a20af673 --- /dev/null +++ b/p2p/security/noise/util.go @@ -0,0 +1,31 @@ +package noise + +import "io" + +// fillBuffer reads from the given reader until the given buffer +// is full +func fillBuffer(buf []byte, reader io.Reader) (int, error) { + total := 0 + for total < len(buf) { + c, err := reader.Read(buf[total:]) + if err != nil { + return total, err + } + total += c + } + return total, nil +} + +// writeAll is a helper that writes to the given io.Writer until all input data +// has been written +func writeAll(writer io.Writer, data []byte) (int, error) { + total := 0 + for total < len(data) { + c, err := writer.Write(data[total:]) + if err != nil { + return total, err + } + total += c + } + return total, nil +} diff --git a/p2p/security/noise/xx_handshake.go b/p2p/security/noise/xx_handshake.go index b32a6c12..f3c0ef1a 100644 --- a/p2p/security/noise/xx_handshake.go +++ b/p2p/security/noise/xx_handshake.go @@ -27,7 +27,7 @@ func (s *secureSession) xx_sendHandshakeMessage(payload []byte, initial_stage bo return fmt.Errorf("xx_sendHandshakeMessage write length err=%s", err) } - _, err = s.insecure.Write(encMsgBuf) + _, err = writeAll(s.insecure, encMsgBuf) if err != nil { log.Error("xx_sendHandshakeMessage initiator=%v err=%s", s.initiator, err) return fmt.Errorf("xx_sendHandshakeMessage write to conn err=%s", err) @@ -46,7 +46,7 @@ func (s *secureSession) xx_recvHandshakeMessage(initial_stage bool) (buf []byte, buf = make([]byte, l) - _, err = s.insecure.Read(buf) + _, err = fillBuffer(buf, s.insecure) if err != nil { return buf, nil, false, fmt.Errorf("xx_recvHandshakeMessage read from conn err=%s", err) }