Fix stream buffering error
This commit is contained in:
parent
2f0ae0c5a7
commit
2f67264507
|
@ -1,7 +1,10 @@
|
|||
package p2pclient
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
|
@ -35,6 +38,37 @@ func converStreamInfo(info *pb.StreamInfo) (*StreamInfo, error) {
|
|||
return streamInfo, nil
|
||||
}
|
||||
|
||||
type byteReaderConn struct {
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func (c *byteReaderConn) ReadByte() (byte, error) {
|
||||
b := make([]byte, 1)
|
||||
_, err := c.Reader.Read(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return b[0], nil
|
||||
}
|
||||
|
||||
func readHeader(r net.Conn) (*bytes.Buffer, error) {
|
||||
len, err := binary.ReadUvarint(&byteReaderConn{r})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outbuf := make([]byte, 8)
|
||||
sz := binary.PutUvarint(outbuf, len)
|
||||
out := bytes.NewBuffer(outbuf[0:sz])
|
||||
n, err := io.CopyN(out, r, int64(len))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n != int64(len) {
|
||||
return nil, fmt.Errorf("read incorrect number of bytes in header: expected %d, got %d", len, n)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// NewStream initializes a new stream on one of the protocols in protos with
|
||||
// the specified peer.
|
||||
func (c *Client) NewStream(peer peer.ID, protos []string) (*StreamInfo, io.ReadWriteCloser, error) {
|
||||
|
@ -42,7 +76,6 @@ func (c *Client) NewStream(peer peer.ID, protos []string) (*StreamInfo, io.ReadW
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
r := ggio.NewDelimitedReader(control, MessageSizeMax)
|
||||
w := ggio.NewDelimitedWriter(control)
|
||||
|
||||
req := &pb.Request{
|
||||
|
@ -58,6 +91,12 @@ func (c *Client) NewStream(peer peer.ID, protos []string) (*StreamInfo, io.ReadW
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
headerbuf, err := readHeader(control)
|
||||
if err != nil {
|
||||
control.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
r := ggio.NewDelimitedReader(headerbuf, MessageSizeMax)
|
||||
res := &pb.Response{}
|
||||
if err = r.ReadMsg(res); err != nil {
|
||||
control.Close()
|
||||
|
@ -94,7 +133,13 @@ func (c *Client) streamDispatcher() {
|
|||
return
|
||||
}
|
||||
|
||||
r := ggio.NewDelimitedReader(conn, MessageSizeMax)
|
||||
headerbuf, err := readHeader(conn)
|
||||
if err != nil {
|
||||
log.Errorf("reading stream header: %s", err)
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
r := ggio.NewDelimitedReader(headerbuf, MessageSizeMax)
|
||||
pbStreamInfo := &pb.StreamInfo{}
|
||||
if err = r.ReadMsg(pbStreamInfo); err != nil {
|
||||
log.Errorf("reading stream info: %s", err)
|
||||
|
|
|
@ -111,12 +111,15 @@ func TestStreams(t *testing.T) {
|
|||
done := make(chan struct{})
|
||||
c1.NewStreamHandler(testprotos, func(info *p2pclient.StreamInfo, conn io.ReadWriteCloser) {
|
||||
defer conn.Close()
|
||||
var buf []byte
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil && n < 4 {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(buf) != "test" {
|
||||
if n != 4 {
|
||||
t.Fatal("expected to read 4 bytes")
|
||||
}
|
||||
if string(buf[0:4]) != "test" {
|
||||
t.Fatalf(`expected "test", got "%s"`, string(buf))
|
||||
}
|
||||
done <- struct{}{}
|
||||
|
@ -126,14 +129,12 @@ func TestStreams(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for i := 0; i < 1000; i++ {
|
||||
n, err := conn.Write([]byte("test"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 4 {
|
||||
t.Fatal("wrote wrong # of bytes")
|
||||
}
|
||||
n, err := conn.Write([]byte("test"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 4 {
|
||||
t.Fatal("wrote wrong # of bytes")
|
||||
}
|
||||
|
||||
select {
|
||||
|
|
Loading…
Reference in New Issue