184 lines
4.1 KiB
Go
184 lines
4.1 KiB
Go
package msgio
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"io"
|
|
"sync"
|
|
|
|
pool "github.com/libp2p/go-buffer-pool"
|
|
)
|
|
|
|
// varintWriter is the underlying type that implements the Writer interface.
// Every message it writes is preceded by the message length encoded as an
// unsigned varint.
type varintWriter struct {
	W io.Writer // destination for both the length prefixes and the message bodies

	lbuf []byte      // for encoding varints; scratch space reused by every WriteMsg call
	lock sync.Locker // for threadsafe writes; held for the whole prefix+body write so lbuf is not shared concurrently
}
|
|
|
|
// NewVarintWriter wraps an io.Writer with a varint msgio framed writer.
|
|
// The msgio.Writer will write the length prefix of every message written
|
|
// as a varint, using https://golang.org/pkg/encoding/binary/#PutUvarint
|
|
func NewVarintWriter(w io.Writer) WriteCloser {
|
|
return &varintWriter{
|
|
W: w,
|
|
lbuf: make([]byte, binary.MaxVarintLen64),
|
|
lock: new(sync.Mutex),
|
|
}
|
|
}
|
|
|
|
func (s *varintWriter) Write(msg []byte) (int, error) {
|
|
err := s.WriteMsg(msg)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return len(msg), nil
|
|
}
|
|
|
|
func (s *varintWriter) WriteMsg(msg []byte) error {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
|
|
length := uint64(len(msg))
|
|
n := binary.PutUvarint(s.lbuf, length)
|
|
if _, err := s.W.Write(s.lbuf[:n]); err != nil {
|
|
return err
|
|
}
|
|
_, err := s.W.Write(msg)
|
|
return err
|
|
}
|
|
|
|
func (s *varintWriter) Close() error {
|
|
if c, ok := s.W.(io.Closer); ok {
|
|
return c.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// varintReader is the underlying type that implements the Reader interface.
// It reads messages that are each preceded by their length encoded as an
// unsigned varint.
type varintReader struct {
	R  io.Reader     // source of the message bodies
	br io.ByteReader // for reading varints; reads from the same stream as R, one byte at a time

	// NOTE(review): lbuf is allocated by the constructor but no varintReader
	// method in this file reads it; it may be vestigial. Confirm before removing.
	lbuf []byte
	next int              // length of the pending message once its prefix has been read; -1 when no prefix is pending
	pool *pool.BufferPool // source of the buffers returned by ReadMsg; ReleaseMsg puts them back
	lock sync.Locker      // serializes NextMsgLen, Read and ReadMsg (guards next and the stream position)
	max  int              // the maximal message size (in bytes) this reader handles
}
|
|
|
|
// NewVarintReader wraps an io.Reader with a varint msgio framed reader.
|
|
// The msgio.Reader will read whole messages at a time (using the length).
|
|
// Varints read according to https://golang.org/pkg/encoding/binary/#ReadUvarint
|
|
// Assumes an equivalent writer on the other side.
|
|
func NewVarintReader(r io.Reader) ReadCloser {
|
|
return NewVarintReaderWithPool(r, pool.GlobalPool)
|
|
}
|
|
|
|
// NewVarintReaderWithPool wraps an io.Reader with a varint msgio framed reader.
|
|
// The msgio.Reader will read whole messages at a time (using the length).
|
|
// Varints read according to https://golang.org/pkg/encoding/binary/#ReadUvarint
|
|
// Assumes an equivalent writer on the other side. It uses a given
|
|
// pool.BufferPool.
|
|
func NewVarintReaderWithPool(r io.Reader, p *pool.BufferPool) ReadCloser {
|
|
if p == nil {
|
|
panic("nil pool")
|
|
}
|
|
return &varintReader{
|
|
R: r,
|
|
br: &simpleByteReader{R: r},
|
|
lbuf: make([]byte, binary.MaxVarintLen64),
|
|
next: -1,
|
|
pool: p,
|
|
lock: new(sync.Mutex),
|
|
max: defaultMaxSize,
|
|
}
|
|
}
|
|
|
|
// NextMsgLen reads the length of the next msg into s.lbuf, and returns it.
|
|
// WARNING: like Read, NextMsgLen is destructive. It reads from the internal
|
|
// reader.
|
|
func (s *varintReader) NextMsgLen() (int, error) {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
return s.nextMsgLen()
|
|
}
|
|
|
|
func (s *varintReader) nextMsgLen() (int, error) {
|
|
if s.next == -1 {
|
|
length, err := binary.ReadUvarint(s.br)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
s.next = int(length)
|
|
}
|
|
return s.next, nil
|
|
}
|
|
|
|
func (s *varintReader) Read(msg []byte) (int, error) {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
|
|
length, err := s.nextMsgLen()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if length > len(msg) {
|
|
return 0, io.ErrShortBuffer
|
|
}
|
|
_, err = io.ReadFull(s.R, msg[:length])
|
|
s.next = -1 // signal we've consumed this msg
|
|
return length, err
|
|
}
|
|
|
|
func (s *varintReader) ReadMsg() ([]byte, error) {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
|
|
length, err := s.nextMsgLen()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if length == 0 {
|
|
s.next = -1
|
|
return nil, nil
|
|
}
|
|
|
|
if length > s.max {
|
|
return nil, ErrMsgTooLarge
|
|
}
|
|
|
|
msg := s.pool.Get(length)
|
|
_, err = io.ReadFull(s.R, msg)
|
|
s.next = -1 // signal we've consumed this msg
|
|
return msg, err
|
|
}
|
|
|
|
func (s *varintReader) ReleaseMsg(msg []byte) {
|
|
s.pool.Put(msg)
|
|
}
|
|
|
|
func (s *varintReader) Close() error {
|
|
if c, ok := s.R.(io.Closer); ok {
|
|
return c.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// simpleByteReader adapts an io.Reader to io.ByteReader by reading exactly
// one byte per ReadByte call. It never buffers ahead, so decoding a varint
// through it does not consume any bytes past the end of the varint.
type simpleByteReader struct {
	R   io.Reader // underlying byte source
	buf []byte    // one-byte scratch buffer, allocated lazily by ReadByte and reused
}
|
|
|
|
func (r *simpleByteReader) ReadByte() (c byte, err error) {
|
|
if r.buf == nil {
|
|
r.buf = make([]byte, 1)
|
|
}
|
|
|
|
if _, err := io.ReadFull(r.R, r.buf); err != nil {
|
|
return 0, err
|
|
}
|
|
return r.buf[0], nil
|
|
}
|