Fix stream buffering error

This commit is contained in:
Cole Brown 2018-10-02 13:40:26 -04:00
parent 2f0ae0c5a7
commit 2f67264507
2 changed files with 59 additions and 13 deletions

View File

@ -1,7 +1,10 @@
package p2pclient
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
@ -35,6 +38,37 @@ func converStreamInfo(info *pb.StreamInfo) (*StreamInfo, error) {
return streamInfo, nil
}
type byteReaderConn struct {
io.Reader
}
func (c *byteReaderConn) ReadByte() (byte, error) {
b := make([]byte, 1)
_, err := c.Reader.Read(b)
if err != nil {
return 0, err
}
return b[0], nil
}
func readHeader(r net.Conn) (*bytes.Buffer, error) {
len, err := binary.ReadUvarint(&byteReaderConn{r})
if err != nil {
return nil, err
}
outbuf := make([]byte, 8)
sz := binary.PutUvarint(outbuf, len)
out := bytes.NewBuffer(outbuf[0:sz])
n, err := io.CopyN(out, r, int64(len))
if err != nil {
return nil, err
}
if n != int64(len) {
return nil, fmt.Errorf("read incorrect number of bytes in header: expected %d, got %d", len, n)
}
return out, nil
}
// NewStream initializes a new stream on one of the protocols in protos with
// the specified peer.
func (c *Client) NewStream(peer peer.ID, protos []string) (*StreamInfo, io.ReadWriteCloser, error) {
@ -42,7 +76,6 @@ func (c *Client) NewStream(peer peer.ID, protos []string) (*StreamInfo, io.ReadW
if err != nil {
return nil, nil, err
}
r := ggio.NewDelimitedReader(control, MessageSizeMax)
w := ggio.NewDelimitedWriter(control)
req := &pb.Request{
@ -58,6 +91,12 @@ func (c *Client) NewStream(peer peer.ID, protos []string) (*StreamInfo, io.ReadW
return nil, nil, err
}
headerbuf, err := readHeader(control)
if err != nil {
control.Close()
return nil, nil, err
}
r := ggio.NewDelimitedReader(headerbuf, MessageSizeMax)
res := &pb.Response{}
if err = r.ReadMsg(res); err != nil {
control.Close()
@ -94,7 +133,13 @@ func (c *Client) streamDispatcher() {
return
}
r := ggio.NewDelimitedReader(conn, MessageSizeMax)
headerbuf, err := readHeader(conn)
if err != nil {
log.Errorf("reading stream header: %s", err)
conn.Close()
continue
}
r := ggio.NewDelimitedReader(headerbuf, MessageSizeMax)
pbStreamInfo := &pb.StreamInfo{}
if err = r.ReadMsg(pbStreamInfo); err != nil {
log.Errorf("reading stream info: %s", err)

View File

@ -111,12 +111,15 @@ func TestStreams(t *testing.T) {
done := make(chan struct{})
c1.NewStreamHandler(testprotos, func(info *p2pclient.StreamInfo, conn io.ReadWriteCloser) {
defer conn.Close()
var buf []byte
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil && n < 4 {
if err != nil {
t.Fatal(err)
}
if string(buf) != "test" {
if n != 4 {
t.Fatal("expected to read 4 bytes")
}
if string(buf[0:4]) != "test" {
t.Fatalf(`expected "test", got "%s"`, string(buf))
}
done <- struct{}{}
@ -126,14 +129,12 @@ func TestStreams(t *testing.T) {
if err != nil {
t.Fatal(err)
}
for i := 0; i < 1000; i++ {
n, err := conn.Write([]byte("test"))
if err != nil {
t.Fatal(err)
}
if n != 4 {
t.Fatal("wrote wrong # of bytes")
}
n, err := conn.Write([]byte("test"))
if err != nil {
t.Fatal(err)
}
if n != 4 {
t.Fatal("wrote wrong # of bytes")
}
select {