Merge pull request #38 from libp2p/fix/stream-reset
Make sure we actually read and write as much data as we expect
This commit is contained in:
commit
3971c24bec
|
@ -28,7 +28,7 @@ func (s *secureSession) ik_sendHandshakeMessage(payload []byte, initial_stage bo
|
|||
}
|
||||
|
||||
// send message
|
||||
_, err = s.insecure.Write(encMsgBuf)
|
||||
_, err = writeAll(s.insecure, encMsgBuf)
|
||||
if err != nil {
|
||||
log.Error("ik_sendHandshakeMessage initiator=%v err=%s", s.initiator, err)
|
||||
return fmt.Errorf("ik_sendHandshakeMessage write to conn err=%s", err)
|
||||
|
@ -45,7 +45,7 @@ func (s *secureSession) ik_recvHandshakeMessage(initial_stage bool) (buf []byte,
|
|||
|
||||
buf = make([]byte, l)
|
||||
|
||||
_, err = s.insecure.Read(buf)
|
||||
_, err = fillBuffer(buf, s.insecure)
|
||||
if err != nil {
|
||||
return buf, nil, false, fmt.Errorf("ik_recvHandshakeMessage read from conn err=%s", err)
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package noise
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
|
@ -12,11 +11,14 @@ import (
|
|||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
mrand "math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const testProtocolID = "/test/noise/integration"
|
||||
|
||||
func generateKey(seed int64) (crypto.PrivKey, error) {
|
||||
var r io.Reader
|
||||
if seed == 0 {
|
||||
|
@ -109,8 +111,8 @@ func TestLibp2pIntegration_NoPipes(t *testing.T) {
|
|||
|
||||
defer hb.Close()
|
||||
|
||||
ha.SetStreamHandler(ID, handleStream)
|
||||
hb.SetStreamHandler(ID, handleStream)
|
||||
ha.SetStreamHandler(testProtocolID, streamHandler(t))
|
||||
hb.SetStreamHandler(testProtocolID, streamHandler(t))
|
||||
|
||||
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID()))
|
||||
if err != nil {
|
||||
|
@ -129,12 +131,12 @@ func TestLibp2pIntegration_NoPipes(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stream, err := ha.NewStream(ctx, hb.ID(), ID)
|
||||
stream, err := ha.NewStream(ctx, hb.ID(), testProtocolID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = stream.Write([]byte("hello\n"))
|
||||
err = writeRandomPayloadAndClose(t, stream)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -166,8 +168,8 @@ func TestLibp2pIntegration_WithPipes(t *testing.T) {
|
|||
|
||||
defer hb.Close()
|
||||
|
||||
ha.SetStreamHandler(ID, handleStream)
|
||||
hb.SetStreamHandler(ID, handleStream)
|
||||
ha.SetStreamHandler(testProtocolID, streamHandler(t))
|
||||
hb.SetStreamHandler(testProtocolID, streamHandler(t))
|
||||
|
||||
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", ha.Addrs()[0].String(), ha.ID()))
|
||||
if err != nil {
|
||||
|
@ -186,12 +188,12 @@ func TestLibp2pIntegration_WithPipes(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stream, err := hb.NewStream(ctx, ha.ID(), ID)
|
||||
stream, err := hb.NewStream(ctx, ha.ID(), testProtocolID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = stream.Write([]byte("hello\n"))
|
||||
err = writeRandomPayloadAndClose(t, stream)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -223,8 +225,8 @@ func TestLibp2pIntegration_XXFallback(t *testing.T) {
|
|||
|
||||
defer hb.Close()
|
||||
|
||||
ha.SetStreamHandler(ID, handleStream)
|
||||
hb.SetStreamHandler(ID, handleStream)
|
||||
ha.SetStreamHandler(testProtocolID, streamHandler(t))
|
||||
hb.SetStreamHandler(testProtocolID, streamHandler(t))
|
||||
|
||||
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID()))
|
||||
if err != nil {
|
||||
|
@ -243,12 +245,12 @@ func TestLibp2pIntegration_XXFallback(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stream, err := hb.NewStream(ctx, ha.ID(), ID)
|
||||
stream, err := hb.NewStream(ctx, ha.ID(), testProtocolID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = stream.Write([]byte("hello\n"))
|
||||
err = writeRandomPayloadAndClose(t, stream)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -275,18 +277,38 @@ func TestConstrucingWithMaker(t *testing.T) {
|
|||
_ = h.Close()
|
||||
}
|
||||
|
||||
func handleStream(stream net.Stream) {
|
||||
defer func() {
|
||||
if err := stream.Close(); err != nil {
|
||||
log.Error("error closing stream", "err", err)
|
||||
}
|
||||
}()
|
||||
func writeRandomPayloadAndClose(t *testing.T, stream net.Stream) error {
|
||||
t.Helper()
|
||||
size := 1 << 24
|
||||
r := mrand.New(mrand.NewSource(42))
|
||||
start := time.Now()
|
||||
lr := io.LimitReader(r, int64(size))
|
||||
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
|
||||
msg, err := rw.Reader.ReadString('\n')
|
||||
c, err := io.Copy(stream, lr)
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
fmt.Println("stream err", err)
|
||||
return
|
||||
return fmt.Errorf("failed to write out bytes: %v", err)
|
||||
}
|
||||
t.Logf("wrote %d bytes in %dms", c, elapsed.Milliseconds())
|
||||
return stream.Close()
|
||||
}
|
||||
|
||||
func streamHandler(t *testing.T) func(net.Stream) {
|
||||
return func(stream net.Stream) {
|
||||
t.Helper()
|
||||
defer func() {
|
||||
if err := stream.Close(); err != nil {
|
||||
t.Error("error closing stream: ", err)
|
||||
}
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
c, err := io.Copy(ioutil.Discard, stream)
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
t.Error("error reading from stream: ", err)
|
||||
return
|
||||
}
|
||||
t.Logf("read %d bytes in %dms", c, elapsed.Milliseconds())
|
||||
}
|
||||
fmt.Printf("got msg: %s", msg)
|
||||
}
|
||||
|
|
|
@ -9,14 +9,14 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
proto "github.com/gogo/protobuf/proto"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
logging "github.com/ipfs/go-log"
|
||||
"github.com/libp2p/go-libp2p-core/crypto"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
|
||||
ik "github.com/libp2p/go-libp2p-noise/ik"
|
||||
pb "github.com/libp2p/go-libp2p-noise/pb"
|
||||
xx "github.com/libp2p/go-libp2p-noise/xx"
|
||||
"github.com/libp2p/go-libp2p-noise/ik"
|
||||
"github.com/libp2p/go-libp2p-noise/pb"
|
||||
"github.com/libp2p/go-libp2p-noise/xx"
|
||||
)
|
||||
|
||||
const payload_string = "noise-libp2p-static-key:"
|
||||
|
@ -55,7 +55,8 @@ type secureSession struct {
|
|||
noiseKeypair *Keypair
|
||||
|
||||
msgBuffer []byte
|
||||
rwLock sync.Mutex
|
||||
readLock sync.Mutex
|
||||
writeLock sync.Mutex
|
||||
}
|
||||
|
||||
type peerInfo struct {
|
||||
|
@ -120,14 +121,14 @@ func (s *secureSession) NoisePrivateKey() [32]byte {
|
|||
|
||||
func (s *secureSession) readLength() (int, error) {
|
||||
buf := make([]byte, 2)
|
||||
_, err := s.insecure.Read(buf)
|
||||
_, err := fillBuffer(buf, s.insecure)
|
||||
return int(binary.BigEndian.Uint16(buf)), err
|
||||
}
|
||||
|
||||
func (s *secureSession) writeLength(length int) error {
|
||||
buf := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(buf, uint16(length))
|
||||
_, err := s.insecure.Write(buf)
|
||||
_, err := writeAll(s.insecure, buf)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -237,6 +238,9 @@ func (s *secureSession) LocalPublicKey() crypto.PubKey {
|
|||
}
|
||||
|
||||
func (s *secureSession) Read(buf []byte) (int, error) {
|
||||
s.readLock.Lock()
|
||||
defer s.readLock.Unlock()
|
||||
|
||||
l := len(buf)
|
||||
|
||||
// if we have previously unread bytes, and they fit into the buf, copy them over and return
|
||||
|
@ -255,7 +259,7 @@ func (s *secureSession) Read(buf []byte) (int, error) {
|
|||
|
||||
// read and decrypt ciphertext
|
||||
ciphertext := make([]byte, l)
|
||||
_, err = s.insecure.Read(ciphertext)
|
||||
_, err = fillBuffer(ciphertext, s.insecure)
|
||||
if err != nil {
|
||||
log.Error("read ciphertext err", err)
|
||||
return 0, err
|
||||
|
@ -317,8 +321,8 @@ func (s *secureSession) SetWriteDeadline(t time.Time) error {
|
|||
}
|
||||
|
||||
func (s *secureSession) Write(in []byte) (int, error) {
|
||||
s.rwLock.Lock()
|
||||
defer s.rwLock.Unlock()
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
|
||||
writeChunk := func(in []byte) (int, error) {
|
||||
ciphertext, err := s.Encrypt(in)
|
||||
|
@ -329,11 +333,11 @@ func (s *secureSession) Write(in []byte) (int, error) {
|
|||
|
||||
err = s.writeLength(len(ciphertext))
|
||||
if err != nil {
|
||||
log.Error("write length err", err)
|
||||
log.Error("write length err: ", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
_, err = s.insecure.Write(ciphertext)
|
||||
_, err = writeAll(s.insecure, ciphertext)
|
||||
return len(in), err
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
package noise
|
||||
|
||||
import "io"
|
||||
|
||||
// fillBuffer reads from the given reader until the given buffer
|
||||
// is full
|
||||
func fillBuffer(buf []byte, reader io.Reader) (int, error) {
|
||||
total := 0
|
||||
for total < len(buf) {
|
||||
c, err := reader.Read(buf[total:])
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
total += c
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// writeAll is a helper that writes to the given io.Writer until all input data
|
||||
// has been written
|
||||
func writeAll(writer io.Writer, data []byte) (int, error) {
|
||||
total := 0
|
||||
for total < len(data) {
|
||||
c, err := writer.Write(data[total:])
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
total += c
|
||||
}
|
||||
return total, nil
|
||||
}
|
|
@ -27,7 +27,7 @@ func (s *secureSession) xx_sendHandshakeMessage(payload []byte, initial_stage bo
|
|||
return fmt.Errorf("xx_sendHandshakeMessage write length err=%s", err)
|
||||
}
|
||||
|
||||
_, err = s.insecure.Write(encMsgBuf)
|
||||
_, err = writeAll(s.insecure, encMsgBuf)
|
||||
if err != nil {
|
||||
log.Error("xx_sendHandshakeMessage initiator=%v err=%s", s.initiator, err)
|
||||
return fmt.Errorf("xx_sendHandshakeMessage write to conn err=%s", err)
|
||||
|
@ -46,7 +46,7 @@ func (s *secureSession) xx_recvHandshakeMessage(initial_stage bool) (buf []byte,
|
|||
|
||||
buf = make([]byte, l)
|
||||
|
||||
_, err = s.insecure.Read(buf)
|
||||
_, err = fillBuffer(buf, s.insecure)
|
||||
if err != nil {
|
||||
return buf, nil, false, fmt.Errorf("xx_recvHandshakeMessage read from conn err=%s", err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue