Merge pull request #38 from libp2p/fix/stream-reset

Make sure we actually read and write as much data as we expect
This commit is contained in:
Yusef Napora 2020-01-29 09:25:31 -05:00 committed by GitHub
commit 3971c24bec
5 changed files with 97 additions and 40 deletions

View File

@ -28,7 +28,7 @@ func (s *secureSession) ik_sendHandshakeMessage(payload []byte, initial_stage bo
}
// send message
_, err = s.insecure.Write(encMsgBuf)
_, err = writeAll(s.insecure, encMsgBuf)
if err != nil {
log.Error("ik_sendHandshakeMessage initiator=%v err=%s", s.initiator, err)
return fmt.Errorf("ik_sendHandshakeMessage write to conn err=%s", err)
@ -45,7 +45,7 @@ func (s *secureSession) ik_recvHandshakeMessage(initial_stage bool) (buf []byte,
buf = make([]byte, l)
_, err = s.insecure.Read(buf)
_, err = fillBuffer(buf, s.insecure)
if err != nil {
return buf, nil, false, fmt.Errorf("ik_recvHandshakeMessage read from conn err=%s", err)
}

View File

@ -1,7 +1,6 @@
package noise
import (
"bufio"
"context"
"crypto/rand"
"fmt"
@ -12,11 +11,14 @@ import (
"github.com/libp2p/go-libp2p-core/peer"
ma "github.com/multiformats/go-multiaddr"
"io"
"io/ioutil"
mrand "math/rand"
"testing"
"time"
)
const testProtocolID = "/test/noise/integration"
func generateKey(seed int64) (crypto.PrivKey, error) {
var r io.Reader
if seed == 0 {
@ -109,8 +111,8 @@ func TestLibp2pIntegration_NoPipes(t *testing.T) {
defer hb.Close()
ha.SetStreamHandler(ID, handleStream)
hb.SetStreamHandler(ID, handleStream)
ha.SetStreamHandler(testProtocolID, streamHandler(t))
hb.SetStreamHandler(testProtocolID, streamHandler(t))
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID()))
if err != nil {
@ -129,12 +131,12 @@ func TestLibp2pIntegration_NoPipes(t *testing.T) {
t.Fatal(err)
}
stream, err := ha.NewStream(ctx, hb.ID(), ID)
stream, err := ha.NewStream(ctx, hb.ID(), testProtocolID)
if err != nil {
t.Fatal(err)
}
_, err = stream.Write([]byte("hello\n"))
err = writeRandomPayloadAndClose(t, stream)
if err != nil {
t.Fatal(err)
}
@ -166,8 +168,8 @@ func TestLibp2pIntegration_WithPipes(t *testing.T) {
defer hb.Close()
ha.SetStreamHandler(ID, handleStream)
hb.SetStreamHandler(ID, handleStream)
ha.SetStreamHandler(testProtocolID, streamHandler(t))
hb.SetStreamHandler(testProtocolID, streamHandler(t))
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", ha.Addrs()[0].String(), ha.ID()))
if err != nil {
@ -186,12 +188,12 @@ func TestLibp2pIntegration_WithPipes(t *testing.T) {
t.Fatal(err)
}
stream, err := hb.NewStream(ctx, ha.ID(), ID)
stream, err := hb.NewStream(ctx, ha.ID(), testProtocolID)
if err != nil {
t.Fatal(err)
}
_, err = stream.Write([]byte("hello\n"))
err = writeRandomPayloadAndClose(t, stream)
if err != nil {
t.Fatal(err)
}
@ -223,8 +225,8 @@ func TestLibp2pIntegration_XXFallback(t *testing.T) {
defer hb.Close()
ha.SetStreamHandler(ID, handleStream)
hb.SetStreamHandler(ID, handleStream)
ha.SetStreamHandler(testProtocolID, streamHandler(t))
hb.SetStreamHandler(testProtocolID, streamHandler(t))
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID()))
if err != nil {
@ -243,12 +245,12 @@ func TestLibp2pIntegration_XXFallback(t *testing.T) {
t.Fatal(err)
}
stream, err := hb.NewStream(ctx, ha.ID(), ID)
stream, err := hb.NewStream(ctx, ha.ID(), testProtocolID)
if err != nil {
t.Fatal(err)
}
_, err = stream.Write([]byte("hello\n"))
err = writeRandomPayloadAndClose(t, stream)
if err != nil {
t.Fatal(err)
}
@ -275,18 +277,38 @@ func TestConstrucingWithMaker(t *testing.T) {
_ = h.Close()
}
func handleStream(stream net.Stream) {
defer func() {
if err := stream.Close(); err != nil {
log.Error("error closing stream", "err", err)
}
}()
func writeRandomPayloadAndClose(t *testing.T, stream net.Stream) error {
t.Helper()
size := 1 << 24
r := mrand.New(mrand.NewSource(42))
start := time.Now()
lr := io.LimitReader(r, int64(size))
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
msg, err := rw.Reader.ReadString('\n')
c, err := io.Copy(stream, lr)
elapsed := time.Since(start)
if err != nil {
fmt.Println("stream err", err)
return
return fmt.Errorf("failed to write out bytes: %v", err)
}
t.Logf("wrote %d bytes in %dms", c, elapsed.Milliseconds())
return stream.Close()
}
func streamHandler(t *testing.T) func(net.Stream) {
return func(stream net.Stream) {
t.Helper()
defer func() {
if err := stream.Close(); err != nil {
t.Error("error closing stream: ", err)
}
}()
start := time.Now()
c, err := io.Copy(ioutil.Discard, stream)
elapsed := time.Since(start)
if err != nil {
t.Error("error reading from stream: ", err)
return
}
t.Logf("read %d bytes in %dms", c, elapsed.Milliseconds())
}
fmt.Printf("got msg: %s", msg)
}

View File

@ -9,14 +9,14 @@ import (
"sync"
"time"
proto "github.com/gogo/protobuf/proto"
"github.com/gogo/protobuf/proto"
logging "github.com/ipfs/go-log"
"github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer"
ik "github.com/libp2p/go-libp2p-noise/ik"
pb "github.com/libp2p/go-libp2p-noise/pb"
xx "github.com/libp2p/go-libp2p-noise/xx"
"github.com/libp2p/go-libp2p-noise/ik"
"github.com/libp2p/go-libp2p-noise/pb"
"github.com/libp2p/go-libp2p-noise/xx"
)
const payload_string = "noise-libp2p-static-key:"
@ -55,7 +55,8 @@ type secureSession struct {
noiseKeypair *Keypair
msgBuffer []byte
rwLock sync.Mutex
readLock sync.Mutex
writeLock sync.Mutex
}
type peerInfo struct {
@ -120,14 +121,14 @@ func (s *secureSession) NoisePrivateKey() [32]byte {
func (s *secureSession) readLength() (int, error) {
buf := make([]byte, 2)
_, err := s.insecure.Read(buf)
_, err := fillBuffer(buf, s.insecure)
return int(binary.BigEndian.Uint16(buf)), err
}
func (s *secureSession) writeLength(length int) error {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(length))
_, err := s.insecure.Write(buf)
_, err := writeAll(s.insecure, buf)
return err
}
@ -237,6 +238,9 @@ func (s *secureSession) LocalPublicKey() crypto.PubKey {
}
func (s *secureSession) Read(buf []byte) (int, error) {
s.readLock.Lock()
defer s.readLock.Unlock()
l := len(buf)
// if we have previously unread bytes, and they fit into the buf, copy them over and return
@ -255,7 +259,7 @@ func (s *secureSession) Read(buf []byte) (int, error) {
// read and decrypt ciphertext
ciphertext := make([]byte, l)
_, err = s.insecure.Read(ciphertext)
_, err = fillBuffer(ciphertext, s.insecure)
if err != nil {
log.Error("read ciphertext err", err)
return 0, err
@ -317,8 +321,8 @@ func (s *secureSession) SetWriteDeadline(t time.Time) error {
}
func (s *secureSession) Write(in []byte) (int, error) {
s.rwLock.Lock()
defer s.rwLock.Unlock()
s.writeLock.Lock()
defer s.writeLock.Unlock()
writeChunk := func(in []byte) (int, error) {
ciphertext, err := s.Encrypt(in)
@ -329,11 +333,11 @@ func (s *secureSession) Write(in []byte) (int, error) {
err = s.writeLength(len(ciphertext))
if err != nil {
log.Error("write length err", err)
log.Error("write length err: ", err)
return 0, err
}
_, err = s.insecure.Write(ciphertext)
_, err = writeAll(s.insecure, ciphertext)
return len(in), err
}

View File

@ -0,0 +1,31 @@
package noise
import "io"
// fillBuffer reads from the given reader until the given buffer
// is full
func fillBuffer(buf []byte, reader io.Reader) (int, error) {
total := 0
for total < len(buf) {
c, err := reader.Read(buf[total:])
if err != nil {
return total, err
}
total += c
}
return total, nil
}
// writeAll is a helper that writes to the given io.Writer until all input data
// has been written
func writeAll(writer io.Writer, data []byte) (int, error) {
total := 0
for total < len(data) {
c, err := writer.Write(data[total:])
if err != nil {
return total, err
}
total += c
}
return total, nil
}

View File

@ -27,7 +27,7 @@ func (s *secureSession) xx_sendHandshakeMessage(payload []byte, initial_stage bo
return fmt.Errorf("xx_sendHandshakeMessage write length err=%s", err)
}
_, err = s.insecure.Write(encMsgBuf)
_, err = writeAll(s.insecure, encMsgBuf)
if err != nil {
log.Error("xx_sendHandshakeMessage initiator=%v err=%s", s.initiator, err)
return fmt.Errorf("xx_sendHandshakeMessage write to conn err=%s", err)
@ -46,7 +46,7 @@ func (s *secureSession) xx_recvHandshakeMessage(initial_stage bool) (buf []byte,
buf = make([]byte, l)
_, err = s.insecure.Read(buf)
_, err = fillBuffer(buf, s.insecure)
if err != nil {
return buf, nil, false, fmt.Errorf("xx_recvHandshakeMessage read from conn err=%s", err)
}