189 lines
4.2 KiB
Go
189 lines
4.2 KiB
Go
|
package msgio
|
||
|
|
||
|
import (
|
||
|
"encoding/binary"
|
||
|
"io"
|
||
|
"sync"
|
||
|
|
||
|
mpool "github.com/libp2p/go-msgio/mpool"
|
||
|
)
|
||
|
|
||
|
// varintWriter is the concrete type behind the WriteCloser returned by
// NewVarintWriter. Each message written to it is sent to W preceded by the
// message length encoded as an unsigned varint.
type varintWriter struct {
	W io.Writer // destination for the framed output

	lbuf []byte      // scratch space for encoding the varint length prefix
	lock sync.Locker // held by WriteMsg and Close so frames are not interleaved
}
|
||
|
|
||
|
// NewVarintWriter wraps an io.Writer with a varint msgio framed writer.
|
||
|
// The msgio.Writer will write the length prefix of every message written
|
||
|
// as a varint, using https://golang.org/pkg/encoding/binary/#PutUvarint
|
||
|
func NewVarintWriter(w io.Writer) WriteCloser {
|
||
|
return &varintWriter{
|
||
|
W: w,
|
||
|
lbuf: make([]byte, binary.MaxVarintLen64),
|
||
|
lock: new(sync.Mutex),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *varintWriter) Write(msg []byte) (int, error) {
|
||
|
err := s.WriteMsg(msg)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
return len(msg), nil
|
||
|
}
|
||
|
|
||
|
func (s *varintWriter) WriteMsg(msg []byte) error {
|
||
|
s.lock.Lock()
|
||
|
defer s.lock.Unlock()
|
||
|
|
||
|
length := uint64(len(msg))
|
||
|
n := binary.PutUvarint(s.lbuf, length)
|
||
|
if _, err := s.W.Write(s.lbuf[:n]); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
_, err := s.W.Write(msg)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (s *varintWriter) Close() error {
|
||
|
s.lock.Lock()
|
||
|
defer s.lock.Unlock()
|
||
|
|
||
|
if c, ok := s.W.(io.Closer); ok {
|
||
|
return c.Close()
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// varintReader is the underlying type that implements the Reader interface.
|
||
|
type varintReader struct {
|
||
|
R io.Reader
|
||
|
br io.ByteReader // for reading varints.
|
||
|
|
||
|
lbuf []byte
|
||
|
next int
|
||
|
pool *mpool.Pool
|
||
|
lock sync.Locker
|
||
|
max int // the maximal message size (in bytes) this reader handles
|
||
|
}
|
||
|
|
||
|
// NewVarintReader wraps an io.Reader with a varint msgio framed reader.
|
||
|
// The msgio.Reader will read whole messages at a time (using the length).
|
||
|
// Varints read according to https://golang.org/pkg/encoding/binary/#ReadUvarint
|
||
|
// Assumes an equivalent writer on the other side.
|
||
|
func NewVarintReader(r io.Reader) ReadCloser {
|
||
|
return NewVarintReaderWithPool(r, &mpool.ByteSlicePool)
|
||
|
}
|
||
|
|
||
|
// NewVarintReaderWithPool wraps an io.Reader with a varint msgio framed reader.
|
||
|
// The msgio.Reader will read whole messages at a time (using the length).
|
||
|
// Varints read according to https://golang.org/pkg/encoding/binary/#ReadUvarint
|
||
|
// Assumes an equivalent writer on the other side. It uses a given mpool.Pool
|
||
|
func NewVarintReaderWithPool(r io.Reader, p *mpool.Pool) ReadCloser {
|
||
|
if p == nil {
|
||
|
panic("nil pool")
|
||
|
}
|
||
|
return &varintReader{
|
||
|
R: r,
|
||
|
br: &simpleByteReader{R: r},
|
||
|
lbuf: make([]byte, binary.MaxVarintLen64),
|
||
|
next: -1,
|
||
|
pool: p,
|
||
|
lock: new(sync.Mutex),
|
||
|
max: defaultMaxSize,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// NextMsgLen reads the length of the next msg into s.lbuf, and returns it.
|
||
|
// WARNING: like Read, NextMsgLen is destructive. It reads from the internal
|
||
|
// reader.
|
||
|
func (s *varintReader) NextMsgLen() (int, error) {
|
||
|
s.lock.Lock()
|
||
|
defer s.lock.Unlock()
|
||
|
return s.nextMsgLen()
|
||
|
}
|
||
|
|
||
|
func (s *varintReader) nextMsgLen() (int, error) {
|
||
|
if s.next == -1 {
|
||
|
length, err := binary.ReadUvarint(s.br)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
s.next = int(length)
|
||
|
}
|
||
|
return s.next, nil
|
||
|
}
|
||
|
|
||
|
func (s *varintReader) Read(msg []byte) (int, error) {
|
||
|
s.lock.Lock()
|
||
|
defer s.lock.Unlock()
|
||
|
|
||
|
length, err := s.nextMsgLen()
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
|
||
|
if length > len(msg) {
|
||
|
return 0, io.ErrShortBuffer
|
||
|
}
|
||
|
_, err = io.ReadFull(s.R, msg[:length])
|
||
|
s.next = -1 // signal we've consumed this msg
|
||
|
return length, err
|
||
|
}
|
||
|
|
||
|
func (s *varintReader) ReadMsg() ([]byte, error) {
|
||
|
s.lock.Lock()
|
||
|
defer s.lock.Unlock()
|
||
|
|
||
|
length, err := s.nextMsgLen()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if length > s.max {
|
||
|
return nil, ErrMsgTooLarge
|
||
|
}
|
||
|
|
||
|
msgb := s.pool.Get(uint32(length))
|
||
|
if msgb == nil {
|
||
|
return nil, io.ErrShortBuffer
|
||
|
}
|
||
|
msg := msgb.([]byte)[:length]
|
||
|
_, err = io.ReadFull(s.R, msg)
|
||
|
s.next = -1 // signal we've consumed this msg
|
||
|
return msg, err
|
||
|
}
|
||
|
|
||
|
func (s *varintReader) ReleaseMsg(msg []byte) {
|
||
|
s.pool.Put(uint32(cap(msg)), msg)
|
||
|
}
|
||
|
|
||
|
func (s *varintReader) Close() error {
|
||
|
s.lock.Lock()
|
||
|
defer s.lock.Unlock()
|
||
|
|
||
|
if c, ok := s.R.(io.Closer); ok {
|
||
|
return c.Close()
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// simpleByteReader adapts an io.Reader into an io.ByteReader by reading
// exactly one byte per call. It does no read-ahead, so it never takes more
// bytes from R than it has returned. That lets a varint prefix be decoded
// from R, and the message body then be read directly from R.
type simpleByteReader struct {
	R   io.Reader
	buf []byte // single-byte scratch buffer, allocated on first use
}

// ReadByte reads and returns the next byte from R. Any error from the
// underlying reader, including io.EOF, is returned with a zero byte.
func (r *simpleByteReader) ReadByte() (c byte, err error) {
	if r.buf == nil {
		r.buf = make([]byte, 1)
	}
	if _, err = io.ReadFull(r.R, r.buf); err != nil {
		return 0, err
	}
	c = r.buf[0]
	return c, nil
}
|