package multiplex
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
logging "github.com/ipfs/go-log"
|
|
pool "github.com/libp2p/go-buffer-pool"
|
|
)
|
|
|
|
// log is the package-level logger for the mplex subsystem.
var log = logging.Logger("mplex")

// MaxMessageSize is the largest payload (in bytes) readNext will accept
// before treating the frame as an error.
var MaxMessageSize = 1 << 20

// Max time to block waiting for a slow reader to read from a stream before
// resetting it. Preferably, we'd have some form of back-pressure mechanism but
// we don't have that in this protocol.
var ReceiveTimeout = 5 * time.Second

// ErrShutdown is returned when operating on a shutdown session
var ErrShutdown = errors.New("session shut down")

// ErrTwoInitiators is returned when both sides think they're the initiator
var ErrTwoInitiators = errors.New("two initiators")

// ErrInvalidState is returned when the other side does something it shouldn't.
// In this case, we close the connection to be safe.
var ErrInvalidState = errors.New("received an unexpected message from the peer")

// errTimeout is returned (as a net.Error-compatible value) when an
// operation gives up waiting; errStreamClosed when a stream is used
// after being closed.
var errTimeout = timeout{}
var errStreamClosed = errors.New("stream closed")

var (
	// NewStreamTimeout bounds how long NewNamedStream waits to queue
	// the open message; ResetStreamTimeout does the same for resets.
	NewStreamTimeout   = time.Minute
	ResetStreamTimeout = 2 * time.Minute

	// WriteCoalesceDelay is how long writeMsg waits for additional
	// small writes before flushing the coalescing buffer.
	WriteCoalesceDelay = 100 * time.Microsecond
)
|
|
|
|
// timeout is a net.Error-compatible error value signalling that an I/O
// deadline elapsed. It always reports itself as temporary and as a
// timeout (see errTimeout).
type timeout struct{}

// Error implements the error interface.
func (timeout) Error() string {
	return "i/o deadline exceeded"
}

// Temporary reports the error as transient (part of net.Error).
func (timeout) Temporary() bool {
	return true
}

// Timeout reports the error as deadline-induced (part of net.Error).
func (timeout) Timeout() bool {
	return true
}
|
|
|
|
// Message tags for the mplex wire protocol. A frame header is
// (streamID << 3) | tag.
//
// +1 for initiator: the low bit of a received tag encodes which side
// sent the message (handleIncoming inspects tag&1 and then
// canonicalizes with tag += tag & 1 back to the even values below).
const (
	newStreamTag = 0 // open a new stream (payload carries the name)
	messageTag   = 2 // data for an existing stream
	closeTag     = 4 // sender is closing its write side
	resetTag     = 6 // abrupt, bidirectional stream termination
)
|
|
|
|
// Multiplex is a mplex session.
type Multiplex struct {
	con       net.Conn      // underlying connection
	buf       *bufio.Reader // buffered reader over con (used by the receive loop)
	nextID    uint64        // next outbound stream id; bumped by nextChanID under chLock
	initiator bool          // whether this side initiated the session

	closed       chan struct{} // closed by cleanup once the receive loop has fully torn down
	shutdown     chan struct{} // closed by closeNoWait as soon as shutdown begins
	shutdownErr  error         // reason for shutdown; read after closed fires
	shutdownLock sync.Mutex    // guards the shutdown transition

	writeCh         chan []byte // queued outbound frames (pooled buffers)
	writeTimer      *time.Timer // coalescing timer used by writeMsg
	writeTimerFired bool        // whether writeTimer fired without being drained

	nstreams chan *Stream // inbound streams awaiting Accept

	channels map[streamID]*Stream // live streams; nil once cleanup has run
	chLock   sync.Mutex           // guards channels and nextID
}
|
|
|
|
// NewMultiplex creates a new multiplexer session.
|
|
func NewMultiplex(con net.Conn, initiator bool) *Multiplex {
|
|
mp := &Multiplex{
|
|
con: con,
|
|
initiator: initiator,
|
|
buf: bufio.NewReader(con),
|
|
channels: make(map[streamID]*Stream),
|
|
closed: make(chan struct{}),
|
|
shutdown: make(chan struct{}),
|
|
writeCh: make(chan []byte, 16),
|
|
writeTimer: time.NewTimer(0),
|
|
nstreams: make(chan *Stream, 16),
|
|
}
|
|
|
|
go mp.handleIncoming()
|
|
go mp.handleOutgoing()
|
|
|
|
return mp
|
|
}
|
|
|
|
func (mp *Multiplex) newStream(id streamID, name string) (s *Stream) {
|
|
s = &Stream{
|
|
id: id,
|
|
name: name,
|
|
dataIn: make(chan []byte, 8),
|
|
reset: make(chan struct{}),
|
|
rDeadline: makePipeDeadline(),
|
|
wDeadline: makePipeDeadline(),
|
|
mp: mp,
|
|
}
|
|
|
|
s.closedLocal, s.doCloseLocal = context.WithCancel(context.Background())
|
|
return
|
|
}
|
|
|
|
// Accept accepts the next stream from the connection.
|
|
func (m *Multiplex) Accept() (*Stream, error) {
|
|
select {
|
|
case s, ok := <-m.nstreams:
|
|
if !ok {
|
|
return nil, errors.New("multiplex closed")
|
|
}
|
|
return s, nil
|
|
case <-m.closed:
|
|
return nil, m.shutdownErr
|
|
}
|
|
}
|
|
|
|
// Close closes the session. It initiates shutdown (closing the
// underlying connection) and then blocks until the receive loop has
// finished cleaning up. It always returns nil.
func (mp *Multiplex) Close() error {
	mp.closeNoWait()

	// Wait for the receive loop to finish.
	<-mp.closed

	return nil
}
|
|
|
|
func (mp *Multiplex) closeNoWait() {
|
|
mp.shutdownLock.Lock()
|
|
select {
|
|
case <-mp.shutdown:
|
|
default:
|
|
mp.con.Close()
|
|
close(mp.shutdown)
|
|
}
|
|
mp.shutdownLock.Unlock()
|
|
}
|
|
|
|
// IsClosed returns true if the session is closed.
|
|
func (mp *Multiplex) IsClosed() bool {
|
|
select {
|
|
case <-mp.closed:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (mp *Multiplex) sendMsg(done <-chan struct{}, header uint64, data []byte) error {
|
|
buf := pool.Get(len(data) + 20)
|
|
|
|
n := 0
|
|
n += binary.PutUvarint(buf[n:], header)
|
|
n += binary.PutUvarint(buf[n:], uint64(len(data)))
|
|
n += copy(buf[n:], data)
|
|
|
|
select {
|
|
case mp.writeCh <- buf[:n]:
|
|
return nil
|
|
case <-mp.shutdown:
|
|
return ErrShutdown
|
|
case <-done:
|
|
return errTimeout
|
|
}
|
|
}
|
|
|
|
func (mp *Multiplex) handleOutgoing() {
|
|
for {
|
|
select {
|
|
case <-mp.shutdown:
|
|
return
|
|
|
|
case data := <-mp.writeCh:
|
|
// FIXME: https://github.com/libp2p/go-libp2p/issues/644
|
|
// write coalescing disabled until this can be fixed.
|
|
//err := mp.writeMsg(data)
|
|
err := mp.doWriteMsg(data)
|
|
pool.Put(data)
|
|
if err != nil {
|
|
// the connection is closed by this time
|
|
log.Warningf("error writing data: %s", err.Error())
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// writeMsg writes one queued frame, opportunistically coalescing
// further small frames that arrive within WriteCoalesceDelay into a
// single 4096-byte buffer to reduce tiny writes on the connection.
//
// NOTE(review): currently unused — handleOutgoing calls doWriteMsg
// directly (see the FIXME there).
func (mp *Multiplex) writeMsg(data []byte) error {
	if len(data) >= 512 {
		// Already a large write; send it as-is, no coalescing.
		err := mp.doWriteMsg(data)
		pool.Put(data)
		return err
	}

	buf := pool.Get(4096)
	defer pool.Put(buf)

	n := copy(buf, data)
	pool.Put(data)

	// Arm the coalescing timer. If the timer fired last time, its
	// channel was already consumed; otherwise Stop it and drain so
	// Reset starts from a clean, stopped timer.
	if !mp.writeTimerFired {
		if !mp.writeTimer.Stop() {
			<-mp.writeTimer.C
		}
	}
	mp.writeTimer.Reset(WriteCoalesceDelay)
	mp.writeTimerFired = false

	for {
		select {
		case data = <-mp.writeCh:
			wr := copy(buf[n:], data)
			if wr < len(data) {
				// we filled the buffer, send it
				err := mp.doWriteMsg(buf)
				if err != nil {
					pool.Put(data)
					return err
				}

				if len(data)-wr >= 512 {
					// the remaining data is not a small write, send it
					err := mp.doWriteMsg(data[wr:])
					pool.Put(data)
					return err
				}

				// Small remainder: start a fresh coalescing round
				// with it at the front of the buffer.
				n = copy(buf, data[wr:])

				// we've written some, reset the timer to coalesce the rest
				if !mp.writeTimer.Stop() {
					<-mp.writeTimer.C
				}
				mp.writeTimer.Reset(WriteCoalesceDelay)
			} else {
				// Frame fit entirely; keep accumulating.
				n += wr
			}

			pool.Put(data)

		case <-mp.writeTimer.C:
			// Delay elapsed: flush what we've coalesced. Record that
			// the timer fired so the next call skips the drain above.
			mp.writeTimerFired = true
			return mp.doWriteMsg(buf[:n])

		case <-mp.shutdown:
			return ErrShutdown
		}
	}
}
|
|
|
|
func (mp *Multiplex) doWriteMsg(data []byte) error {
|
|
if mp.isShutdown() {
|
|
return ErrShutdown
|
|
}
|
|
|
|
_, err := mp.con.Write(data)
|
|
if err != nil {
|
|
mp.closeNoWait()
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// nextChanID allocates the next outbound stream id. It is not
// internally synchronized; its only caller (NewNamedStream) invokes it
// while holding chLock.
func (mp *Multiplex) nextChanID() uint64 {
	out := mp.nextID
	mp.nextID++
	return out
}
|
|
|
|
// NewStream creates a new stream with an empty name; NewNamedStream
// will substitute the decimal stream id as the name.
func (mp *Multiplex) NewStream() (*Stream, error) {
	return mp.NewNamedStream("")
}
|
|
|
|
// NewNamedStream creates a new named stream.
|
|
func (mp *Multiplex) NewNamedStream(name string) (*Stream, error) {
|
|
mp.chLock.Lock()
|
|
|
|
// We could call IsClosed but this is faster (given that we already have
|
|
// the lock).
|
|
if mp.channels == nil {
|
|
mp.chLock.Unlock()
|
|
return nil, ErrShutdown
|
|
}
|
|
|
|
sid := mp.nextChanID()
|
|
header := (sid << 3) | newStreamTag
|
|
|
|
if name == "" {
|
|
name = fmt.Sprint(sid)
|
|
}
|
|
s := mp.newStream(streamID{
|
|
id: sid,
|
|
initiator: true,
|
|
}, name)
|
|
mp.channels[s.id] = s
|
|
mp.chLock.Unlock()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), NewStreamTimeout)
|
|
defer cancel()
|
|
|
|
err := mp.sendMsg(ctx.Done(), header, []byte(name))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
// cleanup tears the session down after the receive loop exits: it
// initiates shutdown, resets/closes every live stream, nils out the
// channel map, records a shutdown error if none was set, and finally
// closes mp.closed to release Close/Accept waiters.
func (mp *Multiplex) cleanup() {
	mp.closeNoWait()
	mp.chLock.Lock()
	defer mp.chLock.Unlock()
	for _, msch := range mp.channels {
		msch.clLock.Lock()
		if !msch.closedRemote {
			msch.closedRemote = true
			// Cancel readers
			close(msch.reset)
		}

		// Cancel the local-close context so writers observe closure.
		msch.doCloseLocal()
		msch.clLock.Unlock()
	}
	// Don't remove this nil assignment. We check if this is nil to check if
	// the connection is closed when we already have the lock (faster than
	// checking if the stream is closed).
	mp.channels = nil
	if mp.shutdownErr == nil {
		mp.shutdownErr = ErrShutdown
	}
	close(mp.closed)
}
|
|
|
|
// handleIncoming is the receive loop. It reads frames from the
// connection, demultiplexes them by stream id, and dispatches on the
// message tag. It runs until a read error or a protocol violation,
// recording the cause in shutdownErr and tearing the session down via
// cleanup on return.
func (mp *Multiplex) handleIncoming() {
	defer mp.cleanup()

	// Reusable timer bounding how long we'll wait on a slow stream
	// reader (ReceiveTimeout). NewTimer(0) starts fired, so Stop+drain
	// to get it into a clean, stopped state before first use.
	recvTimeout := time.NewTimer(0)
	defer recvTimeout.Stop()

	if !recvTimeout.Stop() {
		<-recvTimeout.C
	}

	for {
		chID, tag, err := mp.readNextHeader()
		if err != nil {
			mp.shutdownErr = err
			return
		}

		// The low bit of the tag encodes which side sent the frame.
		remoteIsInitiator := tag&1 == 0
		ch := streamID{
			// true if *I'm* the initiator.
			initiator: !remoteIsInitiator,
			id:        chID,
		}
		// Rounds up the tag:
		// 0 -> 0
		// 1 -> 2
		// 2 -> 2
		// 3 -> 4
		// etc...
		tag += (tag & 1)

		b, err := mp.readNext()
		if err != nil {
			mp.shutdownErr = err
			return
		}

		mp.chLock.Lock()
		msch, ok := mp.channels[ch]
		mp.chLock.Unlock()

		switch tag {
		case newStreamTag:
			if ok {
				// Duplicate stream id is a protocol violation; kill
				// the whole session.
				log.Debugf("received NewStream message for existing stream: %d", ch)
				mp.shutdownErr = ErrInvalidState
				return
			}

			// Payload is the stream name; copy it out and recycle the buffer.
			name := string(b)
			pool.Put(b)

			msch = mp.newStream(ch, name)
			mp.chLock.Lock()
			mp.channels[ch] = msch
			mp.chLock.Unlock()
			select {
			case mp.nstreams <- msch:
			case <-mp.shutdown:
				return
			}

		case resetTag:
			if !ok {
				// This is *ok*. We forget the stream on reset.
				continue
			}
			msch.clLock.Lock()

			isClosed := msch.isClosed()

			if !msch.closedRemote {
				close(msch.reset)
				msch.closedRemote = true
			}

			if !isClosed {
				msch.doCloseLocal()
			}

			msch.clLock.Unlock()

			msch.cancelDeadlines()

			// Reset forgets the stream entirely.
			mp.chLock.Lock()
			delete(mp.channels, ch)
			mp.chLock.Unlock()
		case closeTag:
			if !ok {
				// Close for an unknown (already reset/forgotten) stream.
				continue
			}

			msch.clLock.Lock()

			if msch.closedRemote {
				msch.clLock.Unlock()
				// Technically a bug on the other side. We
				// should consider killing the connection.
				continue
			}

			// Remote is done writing: wake readers with EOF semantics.
			close(msch.dataIn)
			msch.closedRemote = true

			// Only drop the stream once both directions are closed.
			cleanup := msch.isClosed()

			msch.clLock.Unlock()

			if cleanup {
				msch.cancelDeadlines()
				mp.chLock.Lock()
				delete(mp.channels, ch)
				mp.chLock.Unlock()
			}
		case messageTag:
			if !ok {
				// reset stream, return b
				pool.Put(b)

				// This is a perfectly valid case when we reset
				// and forget about the stream.
				log.Debugf("message for non-existant stream, dropping data: %d", ch)
				// go mp.sendResetMsg(ch.header(resetTag), false)
				continue
			}

			msch.clLock.Lock()
			remoteClosed := msch.closedRemote
			msch.clLock.Unlock()
			if remoteClosed {
				// closed stream, return b
				pool.Put(b)

				log.Warningf("Received data from remote after stream was closed by them. (len = %d)", len(b))
				// go mp.sendResetMsg(msch.id.header(resetTag), false)
				continue
			}

			// Hand the payload to the stream's reader, but don't block
			// forever on a slow consumer: give up after ReceiveTimeout
			// and reset the stream instead.
			recvTimeout.Reset(ReceiveTimeout)
			select {
			case msch.dataIn <- b:
			case <-msch.reset:
				pool.Put(b)
			case <-recvTimeout.C:
				pool.Put(b)
				log.Warningf("timed out receiving message into stream queue.")
				// Do not do this asynchronously. Otherwise, we
				// could drop a message, then receive a message,
				// then reset.
				msch.Reset()
				continue
			case <-mp.shutdown:
				pool.Put(b)
				return
			}
			// Disarm the timer; if it fired concurrently without being
			// consumed by the select above, drain its channel so the
			// next Reset starts clean.
			if !recvTimeout.Stop() {
				<-recvTimeout.C
			}
		default:
			log.Debugf("message with unknown header on stream %s", ch)
			if ok {
				msch.Reset()
			}
		}
	}
}
|
|
|
|
func (mp *Multiplex) isShutdown() bool {
|
|
select {
|
|
case <-mp.shutdown:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (mp *Multiplex) sendResetMsg(header uint64, hard bool) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), ResetStreamTimeout)
|
|
defer cancel()
|
|
|
|
err := mp.sendMsg(ctx.Done(), header, nil)
|
|
if err != nil && !mp.isShutdown() {
|
|
if hard {
|
|
log.Warningf("error sending reset message: %s; killing connection", err.Error())
|
|
mp.Close()
|
|
} else {
|
|
log.Debugf("error sending reset message: %s", err.Error())
|
|
}
|
|
}
|
|
}
|
|
|
|
func (mp *Multiplex) readNextHeader() (uint64, uint64, error) {
|
|
h, err := binary.ReadUvarint(mp.buf)
|
|
if err != nil {
|
|
return 0, 0, err
|
|
}
|
|
|
|
// get channel ID
|
|
ch := h >> 3
|
|
|
|
rem := h & 7
|
|
|
|
return ch, rem, nil
|
|
}
|
|
|
|
func (mp *Multiplex) readNext() ([]byte, error) {
|
|
// get length
|
|
l, err := binary.ReadUvarint(mp.buf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if l > uint64(MaxMessageSize) {
|
|
return nil, fmt.Errorf("message size too large!")
|
|
}
|
|
|
|
if l == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
buf := pool.Get(int(l))
|
|
n, err := io.ReadFull(mp.buf, buf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return buf[:n], nil
|
|
}
|
|
|
|
func isFatalNetworkError(err error) bool {
|
|
nerr, ok := err.(net.Error)
|
|
if ok {
|
|
return !(nerr.Timeout() || nerr.Temporary())
|
|
}
|
|
return false
|
|
}
|