diff --git a/p2pclient/streams.go b/p2pclient/streams.go index 59dd193..08ef049 100644 --- a/p2pclient/streams.go +++ b/p2pclient/streams.go @@ -1,7 +1,10 @@ package p2pclient import ( + "bytes" + "encoding/binary" "errors" + "fmt" "io" "net" @@ -35,6 +38,37 @@ func converStreamInfo(info *pb.StreamInfo) (*StreamInfo, error) { return streamInfo, nil } +type byteReaderConn struct { + io.Reader +} + +func (c *byteReaderConn) ReadByte() (byte, error) { + b := make([]byte, 1) + _, err := c.Reader.Read(b) + if err != nil { + return 0, err + } + return b[0], nil +} + +func readHeader(r net.Conn) (*bytes.Buffer, error) { + len, err := binary.ReadUvarint(&byteReaderConn{r}) + if err != nil { + return nil, err + } + outbuf := make([]byte, 8) + sz := binary.PutUvarint(outbuf, len) + out := bytes.NewBuffer(outbuf[0:sz]) + n, err := io.CopyN(out, r, int64(len)) + if err != nil { + return nil, err + } + if n != int64(len) { + return nil, fmt.Errorf("read incorrect number of bytes in header: expected %d, got %d", len, n) + } + return out, nil +} + // NewStream initializes a new stream on one of the protocols in protos with // the specified peer. func (c *Client) NewStream(peer peer.ID, protos []string) (*StreamInfo, io.ReadWriteCloser, error) { @@ -42,7 +76,6 @@ func (c *Client) NewStream(peer peer.ID, protos []string) (*StreamInfo, io.ReadW if err != nil { return nil, nil, err } - r := ggio.NewDelimitedReader(control, MessageSizeMax) w := ggio.NewDelimitedWriter(control) req := &pb.Request{ @@ -58,6 +91,12 @@ func (c *Client) NewStream(peer peer.ID, protos []string) (*StreamInfo, io.ReadW return nil, nil, err } + headerbuf, err := readHeader(control) + if err != nil { + control.Close() + return nil, nil, err + } + r := ggio.NewDelimitedReader(headerbuf, MessageSizeMax) res := &pb.Response{} if err = r.ReadMsg(res); err != nil { control.Close() @@ -94,7 +133,13 @@ func (c *Client) streamDispatcher() { return } - r := ggio.NewDelimitedReader(conn, MessageSizeMax) + headerbuf, err := readHeader(conn) + if err != nil { + log.Errorf("reading stream header: %s", err) + conn.Close() + continue + } + r := ggio.NewDelimitedReader(headerbuf, MessageSizeMax) pbStreamInfo := &pb.StreamInfo{} if err = r.ReadMsg(pbStreamInfo); err != nil { log.Errorf("reading stream info: %s", err) diff --git a/test/integration_test.go b/test/integration_test.go index 11f8d11..f31ffb4 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -111,12 +111,15 @@ func TestStreams(t *testing.T) { done := make(chan struct{}) c1.NewStreamHandler(testprotos, func(info *p2pclient.StreamInfo, conn io.ReadWriteCloser) { defer conn.Close() - var buf []byte + buf := make([]byte, 1024) n, err := conn.Read(buf) - if err != nil && n < 4 { + if err != nil { t.Fatal(err) } - if string(buf) != "test" { + if n != 4 { + t.Fatal("expected to read 4 bytes") + } + if string(buf[0:4]) != "test" { t.Fatalf(`expected "test", got "%s"`, string(buf)) } done <- struct{}{} @@ -126,14 +129,12 @@ func TestStreams(t *testing.T) { if err != nil { t.Fatal(err) } - for i := 0; i < 1000; i++ { - n, err := conn.Write([]byte("test")) - if err != nil { - t.Fatal(err) - } - if n != 4 { - t.Fatal("wrote wrong # of bytes") - } + n, err := conn.Write([]byte("test")) + if err != nil { + t.Fatal(err) + } + if n != 4 { + t.Fatal("wrote wrong # of bytes") } select {