Merge pull request #38 from libp2p/fix/stream-reset

Make sure we actually read and write as much data as we expect
This commit is contained in:
Yusef Napora 2020-01-29 09:25:31 -05:00 committed by GitHub
commit 3971c24bec
5 changed files with 97 additions and 40 deletions

View File

@ -28,7 +28,7 @@ func (s *secureSession) ik_sendHandshakeMessage(payload []byte, initial_stage bo
} }
// send message // send message
_, err = s.insecure.Write(encMsgBuf) _, err = writeAll(s.insecure, encMsgBuf)
if err != nil { if err != nil {
log.Error("ik_sendHandshakeMessage initiator=%v err=%s", s.initiator, err) log.Error("ik_sendHandshakeMessage initiator=%v err=%s", s.initiator, err)
return fmt.Errorf("ik_sendHandshakeMessage write to conn err=%s", err) return fmt.Errorf("ik_sendHandshakeMessage write to conn err=%s", err)
@ -45,7 +45,7 @@ func (s *secureSession) ik_recvHandshakeMessage(initial_stage bool) (buf []byte,
buf = make([]byte, l) buf = make([]byte, l)
_, err = s.insecure.Read(buf) _, err = fillBuffer(buf, s.insecure)
if err != nil { if err != nil {
return buf, nil, false, fmt.Errorf("ik_recvHandshakeMessage read from conn err=%s", err) return buf, nil, false, fmt.Errorf("ik_recvHandshakeMessage read from conn err=%s", err)
} }

View File

@ -1,7 +1,6 @@
package noise package noise
import ( import (
"bufio"
"context" "context"
"crypto/rand" "crypto/rand"
"fmt" "fmt"
@ -12,11 +11,14 @@ import (
"github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peer"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
"io" "io"
"io/ioutil"
mrand "math/rand" mrand "math/rand"
"testing" "testing"
"time" "time"
) )
const testProtocolID = "/test/noise/integration"
func generateKey(seed int64) (crypto.PrivKey, error) { func generateKey(seed int64) (crypto.PrivKey, error) {
var r io.Reader var r io.Reader
if seed == 0 { if seed == 0 {
@ -109,8 +111,8 @@ func TestLibp2pIntegration_NoPipes(t *testing.T) {
defer hb.Close() defer hb.Close()
ha.SetStreamHandler(ID, handleStream) ha.SetStreamHandler(testProtocolID, streamHandler(t))
hb.SetStreamHandler(ID, handleStream) hb.SetStreamHandler(testProtocolID, streamHandler(t))
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID())) addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID()))
if err != nil { if err != nil {
@ -129,12 +131,12 @@ func TestLibp2pIntegration_NoPipes(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
stream, err := ha.NewStream(ctx, hb.ID(), ID) stream, err := ha.NewStream(ctx, hb.ID(), testProtocolID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = stream.Write([]byte("hello\n")) err = writeRandomPayloadAndClose(t, stream)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -166,8 +168,8 @@ func TestLibp2pIntegration_WithPipes(t *testing.T) {
defer hb.Close() defer hb.Close()
ha.SetStreamHandler(ID, handleStream) ha.SetStreamHandler(testProtocolID, streamHandler(t))
hb.SetStreamHandler(ID, handleStream) hb.SetStreamHandler(testProtocolID, streamHandler(t))
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", ha.Addrs()[0].String(), ha.ID())) addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", ha.Addrs()[0].String(), ha.ID()))
if err != nil { if err != nil {
@ -186,12 +188,12 @@ func TestLibp2pIntegration_WithPipes(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
stream, err := hb.NewStream(ctx, ha.ID(), ID) stream, err := hb.NewStream(ctx, ha.ID(), testProtocolID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = stream.Write([]byte("hello\n")) err = writeRandomPayloadAndClose(t, stream)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -223,8 +225,8 @@ func TestLibp2pIntegration_XXFallback(t *testing.T) {
defer hb.Close() defer hb.Close()
ha.SetStreamHandler(ID, handleStream) ha.SetStreamHandler(testProtocolID, streamHandler(t))
hb.SetStreamHandler(ID, handleStream) hb.SetStreamHandler(testProtocolID, streamHandler(t))
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID())) addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID()))
if err != nil { if err != nil {
@ -243,12 +245,12 @@ func TestLibp2pIntegration_XXFallback(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
stream, err := hb.NewStream(ctx, ha.ID(), ID) stream, err := hb.NewStream(ctx, ha.ID(), testProtocolID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = stream.Write([]byte("hello\n")) err = writeRandomPayloadAndClose(t, stream)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -275,18 +277,38 @@ func TestConstrucingWithMaker(t *testing.T) {
_ = h.Close() _ = h.Close()
} }
func handleStream(stream net.Stream) { func writeRandomPayloadAndClose(t *testing.T, stream net.Stream) error {
defer func() { t.Helper()
if err := stream.Close(); err != nil { size := 1 << 24
log.Error("error closing stream", "err", err) r := mrand.New(mrand.NewSource(42))
} start := time.Now()
}() lr := io.LimitReader(r, int64(size))
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream)) c, err := io.Copy(stream, lr)
msg, err := rw.Reader.ReadString('\n') elapsed := time.Since(start)
if err != nil { if err != nil {
fmt.Println("stream err", err) return fmt.Errorf("failed to write out bytes: %v", err)
return }
t.Logf("wrote %d bytes in %dms", c, elapsed.Milliseconds())
return stream.Close()
}
func streamHandler(t *testing.T) func(net.Stream) {
return func(stream net.Stream) {
t.Helper()
defer func() {
if err := stream.Close(); err != nil {
t.Error("error closing stream: ", err)
}
}()
start := time.Now()
c, err := io.Copy(ioutil.Discard, stream)
elapsed := time.Since(start)
if err != nil {
t.Error("error reading from stream: ", err)
return
}
t.Logf("read %d bytes in %dms", c, elapsed.Milliseconds())
} }
fmt.Printf("got msg: %s", msg)
} }

View File

@ -9,14 +9,14 @@ import (
"sync" "sync"
"time" "time"
proto "github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
logging "github.com/ipfs/go-log" logging "github.com/ipfs/go-log"
"github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peer"
ik "github.com/libp2p/go-libp2p-noise/ik" "github.com/libp2p/go-libp2p-noise/ik"
pb "github.com/libp2p/go-libp2p-noise/pb" "github.com/libp2p/go-libp2p-noise/pb"
xx "github.com/libp2p/go-libp2p-noise/xx" "github.com/libp2p/go-libp2p-noise/xx"
) )
const payload_string = "noise-libp2p-static-key:" const payload_string = "noise-libp2p-static-key:"
@ -55,7 +55,8 @@ type secureSession struct {
noiseKeypair *Keypair noiseKeypair *Keypair
msgBuffer []byte msgBuffer []byte
rwLock sync.Mutex readLock sync.Mutex
writeLock sync.Mutex
} }
type peerInfo struct { type peerInfo struct {
@ -120,14 +121,14 @@ func (s *secureSession) NoisePrivateKey() [32]byte {
func (s *secureSession) readLength() (int, error) { func (s *secureSession) readLength() (int, error) {
buf := make([]byte, 2) buf := make([]byte, 2)
_, err := s.insecure.Read(buf) _, err := fillBuffer(buf, s.insecure)
return int(binary.BigEndian.Uint16(buf)), err return int(binary.BigEndian.Uint16(buf)), err
} }
func (s *secureSession) writeLength(length int) error { func (s *secureSession) writeLength(length int) error {
buf := make([]byte, 2) buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(length)) binary.BigEndian.PutUint16(buf, uint16(length))
_, err := s.insecure.Write(buf) _, err := writeAll(s.insecure, buf)
return err return err
} }
@ -237,6 +238,9 @@ func (s *secureSession) LocalPublicKey() crypto.PubKey {
} }
func (s *secureSession) Read(buf []byte) (int, error) { func (s *secureSession) Read(buf []byte) (int, error) {
s.readLock.Lock()
defer s.readLock.Unlock()
l := len(buf) l := len(buf)
// if we have previously unread bytes, and they fit into the buf, copy them over and return // if we have previously unread bytes, and they fit into the buf, copy them over and return
@ -255,7 +259,7 @@ func (s *secureSession) Read(buf []byte) (int, error) {
// read and decrypt ciphertext // read and decrypt ciphertext
ciphertext := make([]byte, l) ciphertext := make([]byte, l)
_, err = s.insecure.Read(ciphertext) _, err = fillBuffer(ciphertext, s.insecure)
if err != nil { if err != nil {
log.Error("read ciphertext err", err) log.Error("read ciphertext err", err)
return 0, err return 0, err
@ -317,8 +321,8 @@ func (s *secureSession) SetWriteDeadline(t time.Time) error {
} }
func (s *secureSession) Write(in []byte) (int, error) { func (s *secureSession) Write(in []byte) (int, error) {
s.rwLock.Lock() s.writeLock.Lock()
defer s.rwLock.Unlock() defer s.writeLock.Unlock()
writeChunk := func(in []byte) (int, error) { writeChunk := func(in []byte) (int, error) {
ciphertext, err := s.Encrypt(in) ciphertext, err := s.Encrypt(in)
@ -329,11 +333,11 @@ func (s *secureSession) Write(in []byte) (int, error) {
err = s.writeLength(len(ciphertext)) err = s.writeLength(len(ciphertext))
if err != nil { if err != nil {
log.Error("write length err", err) log.Error("write length err: ", err)
return 0, err return 0, err
} }
_, err = s.insecure.Write(ciphertext) _, err = writeAll(s.insecure, ciphertext)
return len(in), err return len(in), err
} }

View File

@ -0,0 +1,31 @@
package noise
import "io"
// fillBuffer reads from the given reader until the given buffer
// is full
func fillBuffer(buf []byte, reader io.Reader) (int, error) {
total := 0
for total < len(buf) {
c, err := reader.Read(buf[total:])
if err != nil {
return total, err
}
total += c
}
return total, nil
}
// writeAll is a helper that writes to the given io.Writer until all input data
// has been written
func writeAll(writer io.Writer, data []byte) (int, error) {
total := 0
for total < len(data) {
c, err := writer.Write(data[total:])
if err != nil {
return total, err
}
total += c
}
return total, nil
}

View File

@ -27,7 +27,7 @@ func (s *secureSession) xx_sendHandshakeMessage(payload []byte, initial_stage bo
return fmt.Errorf("xx_sendHandshakeMessage write length err=%s", err) return fmt.Errorf("xx_sendHandshakeMessage write length err=%s", err)
} }
_, err = s.insecure.Write(encMsgBuf) _, err = writeAll(s.insecure, encMsgBuf)
if err != nil { if err != nil {
log.Error("xx_sendHandshakeMessage initiator=%v err=%s", s.initiator, err) log.Error("xx_sendHandshakeMessage initiator=%v err=%s", s.initiator, err)
return fmt.Errorf("xx_sendHandshakeMessage write to conn err=%s", err) return fmt.Errorf("xx_sendHandshakeMessage write to conn err=%s", err)
@ -46,7 +46,7 @@ func (s *secureSession) xx_recvHandshakeMessage(initial_stage bool) (buf []byte,
buf = make([]byte, l) buf = make([]byte, l)
_, err = s.insecure.Read(buf) _, err = fillBuffer(buf, s.insecure)
if err != nil { if err != nil {
return buf, nil, false, fmt.Errorf("xx_recvHandshakeMessage read from conn err=%s", err) return buf, nil, false, fmt.Errorf("xx_recvHandshakeMessage read from conn err=%s", err)
} }