Merge pull request #38 from libp2p/fix/stream-reset
Make sure we actually read and write as much data as we expect
This commit is contained in:
commit
3971c24bec
|
@ -28,7 +28,7 @@ func (s *secureSession) ik_sendHandshakeMessage(payload []byte, initial_stage bo
|
||||||
}
|
}
|
||||||
|
|
||||||
// send message
|
// send message
|
||||||
_, err = s.insecure.Write(encMsgBuf)
|
_, err = writeAll(s.insecure, encMsgBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("ik_sendHandshakeMessage initiator=%v err=%s", s.initiator, err)
|
log.Error("ik_sendHandshakeMessage initiator=%v err=%s", s.initiator, err)
|
||||||
return fmt.Errorf("ik_sendHandshakeMessage write to conn err=%s", err)
|
return fmt.Errorf("ik_sendHandshakeMessage write to conn err=%s", err)
|
||||||
|
@ -45,7 +45,7 @@ func (s *secureSession) ik_recvHandshakeMessage(initial_stage bool) (buf []byte,
|
||||||
|
|
||||||
buf = make([]byte, l)
|
buf = make([]byte, l)
|
||||||
|
|
||||||
_, err = s.insecure.Read(buf)
|
_, err = fillBuffer(buf, s.insecure)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return buf, nil, false, fmt.Errorf("ik_recvHandshakeMessage read from conn err=%s", err)
|
return buf, nil, false, fmt.Errorf("ik_recvHandshakeMessage read from conn err=%s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package noise
|
package noise
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -12,11 +11,14 @@ import (
|
||||||
"github.com/libp2p/go-libp2p-core/peer"
|
"github.com/libp2p/go-libp2p-core/peer"
|
||||||
ma "github.com/multiformats/go-multiaddr"
|
ma "github.com/multiformats/go-multiaddr"
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
mrand "math/rand"
|
mrand "math/rand"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const testProtocolID = "/test/noise/integration"
|
||||||
|
|
||||||
func generateKey(seed int64) (crypto.PrivKey, error) {
|
func generateKey(seed int64) (crypto.PrivKey, error) {
|
||||||
var r io.Reader
|
var r io.Reader
|
||||||
if seed == 0 {
|
if seed == 0 {
|
||||||
|
@ -109,8 +111,8 @@ func TestLibp2pIntegration_NoPipes(t *testing.T) {
|
||||||
|
|
||||||
defer hb.Close()
|
defer hb.Close()
|
||||||
|
|
||||||
ha.SetStreamHandler(ID, handleStream)
|
ha.SetStreamHandler(testProtocolID, streamHandler(t))
|
||||||
hb.SetStreamHandler(ID, handleStream)
|
hb.SetStreamHandler(testProtocolID, streamHandler(t))
|
||||||
|
|
||||||
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID()))
|
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -129,12 +131,12 @@ func TestLibp2pIntegration_NoPipes(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
stream, err := ha.NewStream(ctx, hb.ID(), ID)
|
stream, err := ha.NewStream(ctx, hb.ID(), testProtocolID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = stream.Write([]byte("hello\n"))
|
err = writeRandomPayloadAndClose(t, stream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -166,8 +168,8 @@ func TestLibp2pIntegration_WithPipes(t *testing.T) {
|
||||||
|
|
||||||
defer hb.Close()
|
defer hb.Close()
|
||||||
|
|
||||||
ha.SetStreamHandler(ID, handleStream)
|
ha.SetStreamHandler(testProtocolID, streamHandler(t))
|
||||||
hb.SetStreamHandler(ID, handleStream)
|
hb.SetStreamHandler(testProtocolID, streamHandler(t))
|
||||||
|
|
||||||
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", ha.Addrs()[0].String(), ha.ID()))
|
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", ha.Addrs()[0].String(), ha.ID()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -186,12 +188,12 @@ func TestLibp2pIntegration_WithPipes(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
stream, err := hb.NewStream(ctx, ha.ID(), ID)
|
stream, err := hb.NewStream(ctx, ha.ID(), testProtocolID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = stream.Write([]byte("hello\n"))
|
err = writeRandomPayloadAndClose(t, stream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -223,8 +225,8 @@ func TestLibp2pIntegration_XXFallback(t *testing.T) {
|
||||||
|
|
||||||
defer hb.Close()
|
defer hb.Close()
|
||||||
|
|
||||||
ha.SetStreamHandler(ID, handleStream)
|
ha.SetStreamHandler(testProtocolID, streamHandler(t))
|
||||||
hb.SetStreamHandler(ID, handleStream)
|
hb.SetStreamHandler(testProtocolID, streamHandler(t))
|
||||||
|
|
||||||
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID()))
|
addr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", hb.Addrs()[0].String(), hb.ID()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -243,12 +245,12 @@ func TestLibp2pIntegration_XXFallback(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
stream, err := hb.NewStream(ctx, ha.ID(), ID)
|
stream, err := hb.NewStream(ctx, ha.ID(), testProtocolID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = stream.Write([]byte("hello\n"))
|
err = writeRandomPayloadAndClose(t, stream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -275,18 +277,38 @@ func TestConstrucingWithMaker(t *testing.T) {
|
||||||
_ = h.Close()
|
_ = h.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleStream(stream net.Stream) {
|
func writeRandomPayloadAndClose(t *testing.T, stream net.Stream) error {
|
||||||
defer func() {
|
t.Helper()
|
||||||
if err := stream.Close(); err != nil {
|
size := 1 << 24
|
||||||
log.Error("error closing stream", "err", err)
|
r := mrand.New(mrand.NewSource(42))
|
||||||
}
|
start := time.Now()
|
||||||
}()
|
lr := io.LimitReader(r, int64(size))
|
||||||
|
|
||||||
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
|
c, err := io.Copy(stream, lr)
|
||||||
msg, err := rw.Reader.ReadString('\n')
|
elapsed := time.Since(start)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("stream err", err)
|
return fmt.Errorf("failed to write out bytes: %v", err)
|
||||||
return
|
}
|
||||||
|
t.Logf("wrote %d bytes in %dms", c, elapsed.Milliseconds())
|
||||||
|
return stream.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamHandler(t *testing.T) func(net.Stream) {
|
||||||
|
return func(stream net.Stream) {
|
||||||
|
t.Helper()
|
||||||
|
defer func() {
|
||||||
|
if err := stream.Close(); err != nil {
|
||||||
|
t.Error("error closing stream: ", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
c, err := io.Copy(ioutil.Discard, stream)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("error reading from stream: ", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Logf("read %d bytes in %dms", c, elapsed.Milliseconds())
|
||||||
}
|
}
|
||||||
fmt.Printf("got msg: %s", msg)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,14 +9,14 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
proto "github.com/gogo/protobuf/proto"
|
"github.com/gogo/protobuf/proto"
|
||||||
logging "github.com/ipfs/go-log"
|
logging "github.com/ipfs/go-log"
|
||||||
"github.com/libp2p/go-libp2p-core/crypto"
|
"github.com/libp2p/go-libp2p-core/crypto"
|
||||||
"github.com/libp2p/go-libp2p-core/peer"
|
"github.com/libp2p/go-libp2p-core/peer"
|
||||||
|
|
||||||
ik "github.com/libp2p/go-libp2p-noise/ik"
|
"github.com/libp2p/go-libp2p-noise/ik"
|
||||||
pb "github.com/libp2p/go-libp2p-noise/pb"
|
"github.com/libp2p/go-libp2p-noise/pb"
|
||||||
xx "github.com/libp2p/go-libp2p-noise/xx"
|
"github.com/libp2p/go-libp2p-noise/xx"
|
||||||
)
|
)
|
||||||
|
|
||||||
const payload_string = "noise-libp2p-static-key:"
|
const payload_string = "noise-libp2p-static-key:"
|
||||||
|
@ -55,7 +55,8 @@ type secureSession struct {
|
||||||
noiseKeypair *Keypair
|
noiseKeypair *Keypair
|
||||||
|
|
||||||
msgBuffer []byte
|
msgBuffer []byte
|
||||||
rwLock sync.Mutex
|
readLock sync.Mutex
|
||||||
|
writeLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
type peerInfo struct {
|
type peerInfo struct {
|
||||||
|
@ -120,14 +121,14 @@ func (s *secureSession) NoisePrivateKey() [32]byte {
|
||||||
|
|
||||||
func (s *secureSession) readLength() (int, error) {
|
func (s *secureSession) readLength() (int, error) {
|
||||||
buf := make([]byte, 2)
|
buf := make([]byte, 2)
|
||||||
_, err := s.insecure.Read(buf)
|
_, err := fillBuffer(buf, s.insecure)
|
||||||
return int(binary.BigEndian.Uint16(buf)), err
|
return int(binary.BigEndian.Uint16(buf)), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *secureSession) writeLength(length int) error {
|
func (s *secureSession) writeLength(length int) error {
|
||||||
buf := make([]byte, 2)
|
buf := make([]byte, 2)
|
||||||
binary.BigEndian.PutUint16(buf, uint16(length))
|
binary.BigEndian.PutUint16(buf, uint16(length))
|
||||||
_, err := s.insecure.Write(buf)
|
_, err := writeAll(s.insecure, buf)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -237,6 +238,9 @@ func (s *secureSession) LocalPublicKey() crypto.PubKey {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *secureSession) Read(buf []byte) (int, error) {
|
func (s *secureSession) Read(buf []byte) (int, error) {
|
||||||
|
s.readLock.Lock()
|
||||||
|
defer s.readLock.Unlock()
|
||||||
|
|
||||||
l := len(buf)
|
l := len(buf)
|
||||||
|
|
||||||
// if we have previously unread bytes, and they fit into the buf, copy them over and return
|
// if we have previously unread bytes, and they fit into the buf, copy them over and return
|
||||||
|
@ -255,7 +259,7 @@ func (s *secureSession) Read(buf []byte) (int, error) {
|
||||||
|
|
||||||
// read and decrypt ciphertext
|
// read and decrypt ciphertext
|
||||||
ciphertext := make([]byte, l)
|
ciphertext := make([]byte, l)
|
||||||
_, err = s.insecure.Read(ciphertext)
|
_, err = fillBuffer(ciphertext, s.insecure)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("read ciphertext err", err)
|
log.Error("read ciphertext err", err)
|
||||||
return 0, err
|
return 0, err
|
||||||
|
@ -317,8 +321,8 @@ func (s *secureSession) SetWriteDeadline(t time.Time) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *secureSession) Write(in []byte) (int, error) {
|
func (s *secureSession) Write(in []byte) (int, error) {
|
||||||
s.rwLock.Lock()
|
s.writeLock.Lock()
|
||||||
defer s.rwLock.Unlock()
|
defer s.writeLock.Unlock()
|
||||||
|
|
||||||
writeChunk := func(in []byte) (int, error) {
|
writeChunk := func(in []byte) (int, error) {
|
||||||
ciphertext, err := s.Encrypt(in)
|
ciphertext, err := s.Encrypt(in)
|
||||||
|
@ -329,11 +333,11 @@ func (s *secureSession) Write(in []byte) (int, error) {
|
||||||
|
|
||||||
err = s.writeLength(len(ciphertext))
|
err = s.writeLength(len(ciphertext))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("write length err", err)
|
log.Error("write length err: ", err)
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.insecure.Write(ciphertext)
|
_, err = writeAll(s.insecure, ciphertext)
|
||||||
return len(in), err
|
return len(in), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
package noise
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
|
||||||
|
// fillBuffer reads from the given reader until the given buffer
|
||||||
|
// is full
|
||||||
|
func fillBuffer(buf []byte, reader io.Reader) (int, error) {
|
||||||
|
total := 0
|
||||||
|
for total < len(buf) {
|
||||||
|
c, err := reader.Read(buf[total:])
|
||||||
|
if err != nil {
|
||||||
|
return total, err
|
||||||
|
}
|
||||||
|
total += c
|
||||||
|
}
|
||||||
|
return total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeAll is a helper that writes to the given io.Writer until all input data
|
||||||
|
// has been written
|
||||||
|
func writeAll(writer io.Writer, data []byte) (int, error) {
|
||||||
|
total := 0
|
||||||
|
for total < len(data) {
|
||||||
|
c, err := writer.Write(data[total:])
|
||||||
|
if err != nil {
|
||||||
|
return total, err
|
||||||
|
}
|
||||||
|
total += c
|
||||||
|
}
|
||||||
|
return total, nil
|
||||||
|
}
|
|
@ -27,7 +27,7 @@ func (s *secureSession) xx_sendHandshakeMessage(payload []byte, initial_stage bo
|
||||||
return fmt.Errorf("xx_sendHandshakeMessage write length err=%s", err)
|
return fmt.Errorf("xx_sendHandshakeMessage write length err=%s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.insecure.Write(encMsgBuf)
|
_, err = writeAll(s.insecure, encMsgBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("xx_sendHandshakeMessage initiator=%v err=%s", s.initiator, err)
|
log.Error("xx_sendHandshakeMessage initiator=%v err=%s", s.initiator, err)
|
||||||
return fmt.Errorf("xx_sendHandshakeMessage write to conn err=%s", err)
|
return fmt.Errorf("xx_sendHandshakeMessage write to conn err=%s", err)
|
||||||
|
@ -46,7 +46,7 @@ func (s *secureSession) xx_recvHandshakeMessage(initial_stage bool) (buf []byte,
|
||||||
|
|
||||||
buf = make([]byte, l)
|
buf = make([]byte, l)
|
||||||
|
|
||||||
_, err = s.insecure.Read(buf)
|
_, err = fillBuffer(buf, s.insecure)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return buf, nil, false, fmt.Errorf("xx_recvHandshakeMessage read from conn err=%s", err)
|
return buf, nil, false, fmt.Errorf("xx_recvHandshakeMessage read from conn err=%s", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue