status-go/vendor/github.com/libp2p/go-msgio/varint.go

189 lines
4.2 KiB
Go
Raw Normal View History

package msgio
import (
"encoding/binary"
"io"
"sync"
mpool "github.com/libp2p/go-msgio/mpool"
)
// varintWriter is the underlying type that implements the Writer interface.
// Every message is emitted as an unsigned-varint length prefix followed by
// the message bytes (see WriteMsg).
type varintWriter struct {
// W is the destination that the length prefix and message bytes are written to.
W io.Writer
lbuf []byte // scratch space the varint length prefix is encoded into
lock sync.Locker // held for the whole of WriteMsg and Close
}
// NewVarintWriter returns a msgio WriteCloser layered over w. Each message
// written through it is preceded by its length encoded as an unsigned varint,
// in the format produced by
// https://golang.org/pkg/encoding/binary/#PutUvarint
func NewVarintWriter(w io.Writer) WriteCloser {
	vw := new(varintWriter)
	vw.W = w
	// Large enough for the longest possible 64-bit varint.
	vw.lbuf = make([]byte, binary.MaxVarintLen64)
	vw.lock = &sync.Mutex{}
	return vw
}
// Write sends msg as one length-prefixed frame by delegating to WriteMsg.
// It returns len(msg) when the frame was written, or 0 together with the
// error reported by WriteMsg.
func (s *varintWriter) Write(msg []byte) (int, error) {
	if err := s.WriteMsg(msg); err != nil {
		return 0, err
	}
	return len(msg), nil
}
// WriteMsg writes msg to the underlying writer as a single frame: the length
// of msg as an unsigned varint, then the bytes of msg. The prefix and the body
// are issued as two separate Write calls while the writer's lock is held; the
// body is not attempted if writing the prefix fails.
func (s *varintWriter) WriteMsg(msg []byte) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	prefixLen := binary.PutUvarint(s.lbuf, uint64(len(msg)))
	_, err := s.W.Write(s.lbuf[:prefixLen])
	if err == nil {
		_, err = s.W.Write(msg)
	}
	return err
}
// Close closes the underlying writer when it also implements io.Closer and
// returns that Close's result; for any other writer it does nothing and
// returns nil. The writer's lock is held while closing.
func (s *varintWriter) Close() error {
	s.lock.Lock()
	defer s.lock.Unlock()

	closer, closable := s.W.(io.Closer)
	if !closable {
		return nil
	}
	return closer.Close()
}
// varintReader is the underlying type that implements the Reader interface.
// It reads frames made of an unsigned-varint length prefix followed by that
// many message bytes.
type varintReader struct {
// R is the source that message bodies are read from.
R io.Reader
br io.ByteReader // for reading varints.
// lbuf is allocated by NewVarintReaderWithPool; none of the reader methods
// visible in this file reference it.
lbuf []byte
// next caches the length of the message whose prefix has been read but
// whose body has not been consumed yet; -1 means no length is pending.
next int
// pool supplies the buffers returned by ReadMsg and takes them back in
// ReleaseMsg.
pool *mpool.Pool
// lock is held by NextMsgLen, Read, ReadMsg and Close.
lock sync.Locker
max int // the maximal message size (in bytes) this reader handles
}
// NewVarintReader returns a msgio ReadCloser layered over r that yields one
// whole message per read, using a varint length prefix to find the message
// boundary. Prefixes are decoded as described at
// https://golang.org/pkg/encoding/binary/#ReadUvarint
// and the peer is assumed to frame its messages the same way. Message buffers
// are drawn from mpool.ByteSlicePool.
func NewVarintReader(r io.Reader) ReadCloser {
	defaultPool := &mpool.ByteSlicePool
	return NewVarintReaderWithPool(r, defaultPool)
}
// NewVarintReaderWithPool returns a msgio ReadCloser layered over r that
// yields one whole message per read, using a varint length prefix to find the
// message boundary. Prefixes are decoded as described at
// https://golang.org/pkg/encoding/binary/#ReadUvarint
// and the peer is assumed to frame its messages the same way. Message buffers
// are drawn from p. It panics if p is nil.
func NewVarintReaderWithPool(r io.Reader, p *mpool.Pool) ReadCloser {
	if p == nil {
		panic("nil pool")
	}

	vr := &varintReader{
		R:    r,
		pool: p,
		next: -1, // no length prefix has been read yet
		max:  defaultMaxSize,
	}
	vr.br = &simpleByteReader{R: r}
	vr.lbuf = make([]byte, binary.MaxVarintLen64)
	vr.lock = new(sync.Mutex)
	return vr
}
// NextMsgLen returns the length of the next message. If no length is pending
// it reads the varint prefix from the underlying reader; the value is then
// remembered until Read or ReadMsg consumes the message body.
// WARNING: like Read, NextMsgLen is destructive. It reads from the internal
// reader.
func (s *varintReader) NextMsgLen() (int, error) {
	s.lock.Lock()
	defer s.lock.Unlock()

	n, err := s.nextMsgLen()
	return n, err
}
// nextMsgLen returns the length of the next message, reading its varint
// prefix from s.br only when no length is pending (s.next == -1). The result
// is cached in s.next until Read or ReadMsg resets it. Callers must hold
// s.lock.
//
// The prefix comes from the peer and may encode any uint64. A value that does
// not fit in a non-negative int is clamped to the largest int instead of being
// converted directly: a plain int(length) would wrap to a negative number
// (2^64-1 becomes -1, which is also the "nothing pending" sentinel) or, on
// 32-bit platforms, silently truncate. The clamped value always exceeds any
// realistic buffer or size limit, so the callers' existing size checks reject
// the frame.
//
// It returns 0 and the error from binary.ReadUvarint if the prefix cannot be
// read; s.next is left at -1 in that case.
func (s *varintReader) nextMsgLen() (int, error) {
	// maxInt is the largest value an int can hold on this platform.
	const maxInt = int(^uint(0) >> 1)

	if s.next == -1 {
		length, err := binary.ReadUvarint(s.br)
		if err != nil {
			return 0, err
		}
		if length > uint64(maxInt) {
			s.next = maxInt
		} else {
			s.next = int(length)
		}
	}
	return s.next, nil
}
// Read reads the next whole message into msg and returns the number of bytes
// read.
//
// If msg is too small for the pending message, Read returns io.ErrShortBuffer
// without consuming the body; the length stays pending, so the call can be
// repeated with a larger buffer. A negative length is rejected the same way: a
// hostile prefix above the int range can wrap negative when converted, and
// slicing msg with it would panic.
//
// Once the body read has been attempted the pending length is cleared,
// whether or not the read succeeded. On a short body the count returned is
// the number of bytes actually read, together with the error from
// io.ReadFull.
func (s *varintReader) Read(msg []byte) (int, error) {
	s.lock.Lock()
	defer s.lock.Unlock()

	length, err := s.nextMsgLen()
	if err != nil {
		return 0, err
	}
	if length < 0 || length > len(msg) {
		return 0, io.ErrShortBuffer
	}

	n, err := io.ReadFull(s.R, msg[:length])
	s.next = -1 // signal we've consumed this msg
	return n, err
}
// ReadMsg reads the next whole message into a buffer obtained from the
// reader's pool and returns it sliced to the message length. The buffer can be
// handed back with ReleaseMsg.
//
// It returns ErrMsgTooLarge, without consuming the body, when the pending
// length exceeds the reader's maximum message size. A negative length is
// rejected the same way: a hostile prefix above the int range can wrap
// negative when converted, and would otherwise reach the pool with a wrapped
// size and panic when the buffer is sliced. io.ErrShortBuffer is returned if
// the pool yields no buffer.
//
// Once the body read has been attempted the pending length is cleared,
// whether or not the read succeeded; on failure the buffer is returned along
// with the error from io.ReadFull.
func (s *varintReader) ReadMsg() ([]byte, error) {
	s.lock.Lock()
	defer s.lock.Unlock()

	length, err := s.nextMsgLen()
	if err != nil {
		return nil, err
	}
	if length < 0 || length > s.max {
		return nil, ErrMsgTooLarge
	}

	msgb := s.pool.Get(uint32(length))
	if msgb == nil {
		return nil, io.ErrShortBuffer
	}
	msg := msgb.([]byte)[:length]

	_, err = io.ReadFull(s.R, msg)
	s.next = -1 // signal we've consumed this msg
	return msg, err
}
// ReleaseMsg hands msg back to the reader's pool, filed under the capacity of
// the slice rather than its length.
func (s *varintReader) ReleaseMsg(msg []byte) {
	capacity := uint32(cap(msg))
	s.pool.Put(capacity, msg)
}
// Close closes the underlying reader when it also implements io.Closer and
// returns that Close's result; for any other reader it does nothing and
// returns nil. The reader's lock is held while closing.
func (s *varintReader) Close() error {
	s.lock.Lock()
	defer s.lock.Unlock()

	closer, closable := s.R.(io.Closer)
	if !closable {
		return nil
	}
	return closer.Close()
}
// simpleByteReader adapts an io.Reader to io.ByteReader by reading exactly one
// byte from R per ReadByte call.
type simpleByteReader struct {
	R   io.Reader
	buf []byte // one-byte scratch buffer, allocated on first use
}

// ReadByte reads a single byte from R. It returns 0 and the error from
// io.ReadFull if the byte cannot be read.
func (r *simpleByteReader) ReadByte() (c byte, err error) {
	if r.buf == nil {
		r.buf = make([]byte, 1)
	}
	if _, err = io.ReadFull(r.R, r.buf); err != nil {
		return 0, err
	}
	c = r.buf[0]
	return c, nil
}