2019-06-12 13:12:00 +02:00

474 lines
9.9 KiB
Go

package yamux
import (
"io"
"sync"
"sync/atomic"
"time"
"github.com/libp2p/go-buffer-pool"
)
// streamState identifies where a Stream is in its lifecycle.
type streamState int

// The possible stream states, numbered from zero in declaration order.
const (
	// streamInit: nothing has been sent yet; the first outgoing frame
	// carries SYN.
	streamInit streamState = iota
	// streamSYNSent: SYN has been sent and no ACK has been seen yet.
	streamSYNSent
	// streamSYNReceived: the first outgoing frame will carry ACK.
	streamSYNReceived
	// streamEstablished: the SYN/ACK exchange has completed.
	streamEstablished
	// streamLocalClose: the local side has closed; writes are rejected.
	streamLocalClose
	// streamRemoteClose: the remote side has sent FIN; reads drain the
	// buffer and then report io.EOF.
	streamRemoteClose
	// streamClosed: both sides have closed.
	streamClosed
	// streamReset: the stream was aborted locally, remotely, or by the
	// session shutting down.
	streamReset
)
// Stream represents a single logical, flow-controlled stream multiplexed
// over a Session.
type Stream struct {
	// recvWindow is how many more bytes the peer may send before it needs a
	// window update. Guarded by recvLock.
	recvWindow uint32
	// sendWindow is how many more bytes we may send. Accessed with
	// sync/atomic.
	sendWindow uint32

	// id is the stream identifier used in every frame header.
	id uint32
	// session is the owning session; all frames are sent through it.
	session *Session

	// state is the current lifecycle state. Guarded by stateLock.
	state     streamState
	stateLock sync.Mutex

	// recvBuf holds received data that has not been read yet. Guarded by
	// recvLock.
	recvLock sync.Mutex
	recvBuf  pool.Buffer

	// sendLock serializes calls to Write.
	sendLock sync.Mutex

	// recvNotifyCh and sendNotifyCh wake goroutines blocked in Read and
	// write respectively.
	recvNotifyCh chan struct{}
	sendNotifyCh chan struct{}

	// readDeadline and writeDeadline bound how long Read and Write block.
	readDeadline  pipeDeadline
	writeDeadline pipeDeadline
}
// newStream builds a Stream with the given ID and initial state inside
// session. Both flow-control windows start at initialStreamWindow, and the
// notification channels are buffered with capacity one.
func newStream(session *Session, id uint32, state streamState) *Stream {
	return &Stream{
		id:            id,
		session:       session,
		state:         state,
		recvWindow:    initialStreamWindow,
		sendWindow:    initialStreamWindow,
		readDeadline:  makePipeDeadline(),
		writeDeadline: makePipeDeadline(),
		recvNotifyCh:  make(chan struct{}, 1),
		sendNotifyCh:  make(chan struct{}, 1),
	}
}
// Session returns the session this stream belongs to.
func (s *Stream) Session() *Session {
	return s.session
}
// StreamID returns the numeric identifier of this stream.
func (s *Stream) StreamID() uint32 {
	return s.id
}
// Read copies buffered stream data into b, blocking until data is available.
//
// It returns io.EOF once the stream is remote-closed or closed and the
// receive buffer is empty, ErrConnectionReset if the stream has been reset,
// and ErrTimeout if the read deadline fires while waiting. After a successful
// read it may send a window update; an error from that send is returned
// together with the byte count.
func (s *Stream) Read(b []byte) (n int, err error) {
	// Signal the receive channel again on the way out, whatever the outcome.
	defer asyncNotify(s.recvNotifyCh)

	for {
		// Check for a terminal state first.
		s.stateLock.Lock()
		switch s.state {
		case streamRemoteClose, streamClosed:
			// The peer will send no more data, but anything already
			// buffered must still be delivered before reporting EOF.
			s.recvLock.Lock()
			drained := s.recvBuf.Len() == 0
			s.recvLock.Unlock()
			if drained {
				s.stateLock.Unlock()
				return 0, io.EOF
			}
		case streamReset:
			s.stateLock.Unlock()
			return 0, ErrConnectionReset
		}
		s.stateLock.Unlock()

		// Hand over whatever is buffered, if anything.
		s.recvLock.Lock()
		if s.recvBuf.Len() > 0 {
			n, _ = s.recvBuf.Read(b)
			s.recvLock.Unlock()

			// Possibly let the peer know it may send more.
			err = s.sendWindowUpdate()
			return n, err
		}
		s.recvLock.Unlock()

		// Nothing buffered: wait for a notification, then re-check.
		select {
		case <-s.recvNotifyCh:
		case <-s.readDeadline.wait():
			return 0, ErrTimeout
		}
	}
}
// Write sends all of b on the stream, splitting it into as many frames as
// the send window and message size limit require. Concurrent calls are
// serialized by sendLock. It returns the number of bytes sent and the first
// error encountered, if any.
func (s *Stream) Write(b []byte) (n int, err error) {
	s.sendLock.Lock()
	defer s.sendLock.Unlock()

	for n < len(b) {
		var sent int
		sent, err = s.write(b[n:])
		n += sent
		if err != nil {
			return n, err
		}
	}
	return n, nil
}
// write sends a single data frame carrying a prefix of b and returns how
// many bytes were sent, which may be fewer than len(b).
//
// It returns ErrStreamClosed if the stream is locally closed or closed,
// ErrConnectionReset if it has been reset, and ErrTimeout if the write
// deadline fires while waiting for the send window to open.
func (s *Stream) write(b []byte) (n int, err error) {
	for {
		// Refuse to write on a stream that can no longer send.
		s.stateLock.Lock()
		switch s.state {
		case streamLocalClose, streamClosed:
			s.stateLock.Unlock()
			return 0, ErrStreamClosed
		case streamReset:
			s.stateLock.Unlock()
			return 0, ErrConnectionReset
		}
		s.stateLock.Unlock()

		// With an exhausted send window, wait for a notification and then
		// start over, since the state may have changed in the meantime.
		window := atomic.LoadUint32(&s.sendWindow)
		if window == 0 {
			select {
			case <-s.sendNotifyCh:
				continue
			case <-s.writeDeadline.wait():
				return 0, ErrTimeout
			}
		}

		// Pick up SYN/ACK if this is the first frame sent on the stream.
		flags := s.sendFlags()

		// Send at most min(window, maximum payload per message, len(b)).
		size := min(window, s.session.config.MaxMessageSize-headerSize, uint32(len(b)))

		hdr := encode(typeData, flags, s.id, size)
		if err = s.session.sendMsg(hdr, b[:size], s.writeDeadline.wait()); err != nil {
			return 0, err
		}

		// Adding the two's complement of size subtracts it from the window.
		atomic.AddUint32(&s.sendWindow, ^uint32(size-1))
		return int(size), err
	}
}
// sendFlags returns the handshake flag, if any, that the next outgoing frame
// must carry, and advances the stream state accordingly: flagSYN when the
// stream is in streamInit (moving it to streamSYNSent), flagACK when it is in
// streamSYNReceived (moving it to streamEstablished), and zero otherwise.
// It acquires stateLock.
func (s *Stream) sendFlags() uint16 {
	s.stateLock.Lock()
	defer s.stateLock.Unlock()

	switch s.state {
	case streamInit:
		s.state = streamSYNSent
		return flagSYN
	case streamSYNReceived:
		s.state = streamEstablished
		return flagACK
	default:
		return 0
	}
}
// sendWindowUpdate grows the receive window back towards
// MaxStreamWindowSize and tells the peer about it, enabling further writes
// on the remote side. The update is skipped when the gain is less than half
// of the maximum window and there is no pending SYN/ACK flag to deliver.
//
// The caller must NOT hold recvLock or stateLock: both are acquired here.
// It returns any error from sending the frame.
func (s *Stream) sendWindowUpdate() error {
	// Resolve the handshake flags before taking recvLock. sendFlags acquires
	// stateLock, while Read acquires stateLock and then recvLock; calling
	// sendFlags with recvLock held inverts that order and can deadlock
	// against a concurrent Read.
	flags := s.sendFlags()

	maxWindow := s.session.config.MaxStreamWindowSize

	s.recvLock.Lock()
	// Space that is neither occupied by unread data nor already advertised.
	delta := (maxWindow - uint32(s.recvBuf.Len())) - s.recvWindow

	// Small updates are not worth a frame unless flags must go out anyway.
	if delta < (maxWindow/2) && flags == 0 {
		s.recvLock.Unlock()
		return nil
	}

	s.recvWindow += delta
	s.recvLock.Unlock()

	hdr := encode(typeWindowUpdate, flags, s.id, delta)
	return s.session.sendMsg(hdr, nil, nil)
}
// sendClose sends a zero-length window update frame carrying FIN, plus any
// pending SYN/ACK flag reported by sendFlags.
func (s *Stream) sendClose() error {
	flags := s.sendFlags() | flagFIN
	return s.session.sendMsg(encode(typeWindowUpdate, flags, s.id, 0), nil, nil)
}
// sendReset sends a zero-length window update frame carrying RST.
func (s *Stream) sendReset() error {
	rst := encode(typeWindowUpdate, flagRST, s.id, 0)
	return s.session.sendMsg(rst, nil, nil)
}
// Reset forcibly closes the stream.
//
// A stream on which nothing has been sent yet is marked reset locally
// without sending anything. A stream that is already closed or reset is
// left alone. In every other state the stream is marked reset, an RST frame
// is sent, blocked readers and writers are woken, and the stream is cleaned
// up. The returned error is the result of sending the RST.
func (s *Stream) Reset() error {
	s.stateLock.Lock()
	switch s.state {
	case streamInit:
		// The peer has never heard of this stream; no RST is needed.
		s.state = streamReset
		s.stateLock.Unlock()
		return nil
	case streamClosed, streamReset:
		s.stateLock.Unlock()
		return nil
	case streamSYNSent, streamSYNReceived, streamEstablished,
		streamLocalClose, streamRemoteClose:
		// The peer knows about the stream: continue below and send RST.
	default:
		panic("unhandled state")
	}
	s.state = streamReset
	s.stateLock.Unlock()

	err := s.sendReset()
	s.notifyWaiting()
	s.cleanup()
	return err
}
// Close closes the local side of the stream by sending FIN.
//
// If the remote side has already closed, the stream becomes fully closed and
// is cleaned up after the FIN is sent. Calling Close on a stream that is
// already locally closed, closed or reset does nothing and returns nil.
// Otherwise the returned error is the result of sending the FIN.
func (s *Stream) Close() error {
	fullyClosed := false

	s.stateLock.Lock()
	switch s.state {
	case streamInit, streamSYNSent, streamSYNReceived, streamEstablished:
		s.state = streamLocalClose
	case streamRemoteClose:
		// Both directions are now closed.
		s.state = streamClosed
		fullyClosed = true
	case streamLocalClose, streamClosed, streamReset:
		// Nothing left to close.
		s.stateLock.Unlock()
		return nil
	default:
		panic("unhandled state")
	}
	s.stateLock.Unlock()

	err := s.sendClose()
	s.notifyWaiting()
	if fullyClosed {
		s.cleanup()
	}
	return err
}
// forceClose is used when the session is exiting. Unless the stream is
// already fully closed it is marked as reset; blocked readers and writers
// are then woken and both deadlines are cleared.
func (s *Stream) forceClose() {
	s.stateLock.Lock()
	// A stream in streamClosed was already closed successfully and simply
	// has not been removed from the session's list of streams yet.
	if s.state != streamClosed {
		s.state = streamReset
	}
	s.stateLock.Unlock()

	s.notifyWaiting()

	// Clear BOTH deadlines. Previously readDeadline was cleared twice and
	// writeDeadline never, leaving any pending write deadline in place.
	s.readDeadline.set(time.Time{})
	s.writeDeadline.set(time.Time{})
}
// cleanup is called once the stream is fully closed or reset. It removes
// the stream from its session and clears both deadlines.
func (s *Stream) cleanup() {
	s.session.closeStream(s.id)

	// Clear BOTH deadlines. Previously readDeadline was cleared twice and
	// writeDeadline never, leaving any pending write deadline in place.
	s.readDeadline.set(time.Time{})
	s.writeDeadline.set(time.Time{})
}
// processFlags applies the flags of an incoming frame to the stream state.
//
// ACK moves a stream in streamSYNSent to streamEstablished and reports the
// stream to the session as established. FIN records the remote close, or
// completes the close if the local side already closed. RST resets the
// stream. Blocked readers and writers are woken on every FIN or RST
// transition. ErrUnexpectedFlag is returned for a FIN in any other state.
//
// stateLock is acquired here, so the caller must not hold it.
func (s *Stream) processFlags(flags uint16) error {
	// cleanup must run without stateLock held. This defer is registered
	// first, so it runs last, after the deferred unlock below.
	needCleanup := false
	defer func() {
		if needCleanup {
			s.cleanup()
		}
	}()

	s.stateLock.Lock()
	defer s.stateLock.Unlock()

	if flags&flagACK == flagACK {
		if s.state == streamSYNSent {
			s.state = streamEstablished
		}
		s.session.establishStream(s.id)
	}

	if flags&flagFIN == flagFIN {
		switch s.state {
		case streamSYNSent, streamSYNReceived, streamEstablished:
			s.state = streamRemoteClose
			s.notifyWaiting()
		case streamLocalClose:
			// Both sides have now closed.
			s.state = streamClosed
			needCleanup = true
			s.notifyWaiting()
		default:
			s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state)
			return ErrUnexpectedFlag
		}
	}

	if flags&flagRST == flagRST {
		s.state = streamReset
		needCleanup = true
		s.notifyWaiting()
	}
	return nil
}
// notifyWaiting signals both the receive and the send notification channel,
// waking any goroutine blocked in Read or write.
func (s *Stream) notifyWaiting() {
	asyncNotify(s.recvNotifyCh)
	asyncNotify(s.sendNotifyCh)
}
// incrSendWindow handles an incoming window update: it applies the frame's
// flags, adds the header's length to the send window, and wakes a blocked
// writer. If the flags are rejected, the error is returned and the window is
// left unchanged.
func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
	err := s.processFlags(flags)
	if err != nil {
		return err
	}

	atomic.AddUint32(&s.sendWindow, hdr.Length())
	asyncNotify(s.sendNotifyCh)
	return nil
}
// readData handles an incoming data frame: it applies the frame's flags and
// then copies the payload, whose size is given by hdr.Length(), from conn
// into the receive buffer.
//
// It returns ErrRecvWindowExceeded if the payload is larger than the
// remaining receive window, or the error from processFlags or from reading
// conn. On success the receive window shrinks by the payload length and a
// blocked reader is woken.
func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
	if err := s.processFlags(flags); err != nil {
		return err
	}

	// A frame without payload only carries flags.
	length := hdr.Length()
	if length == 0 {
		return nil
	}

	// Never consume more than this frame's payload from the connection.
	conn = &io.LimitedReader{R: conn, N: int64(length)}

	s.recvLock.Lock()

	// Check that our receive window is not exceeded.
	if length > s.recvWindow {
		s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
		// Release the lock on this path too; previously it stayed held and
		// every later user of recvLock on this stream blocked forever.
		s.recvLock.Unlock()
		return ErrRecvWindowExceeded
	}

	// Copy the payload into the buffer.
	s.recvBuf.Grow(int(length))
	if _, err := io.Copy(&s.recvBuf, conn); err != nil {
		s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
		s.recvLock.Unlock()
		return err
	}

	// Decrement the receive window.
	s.recvWindow -= length
	s.recvLock.Unlock()

	// Unblock any readers.
	asyncNotify(s.recvNotifyCh)
	return nil
}
// SetDeadline sets both the read and the write deadline to t. The read
// deadline is set first; if that fails the write deadline is left untouched
// and the error is returned.
func (s *Stream) SetDeadline(t time.Time) error {
	if err := s.SetReadDeadline(t); err != nil {
		return err
	}
	return s.SetWriteDeadline(t)
}
// SetReadDeadline sets the deadline for future Read calls. It does nothing
// when the stream is closed, remote-closed or reset. It always returns nil.
func (s *Stream) SetReadDeadline(t time.Time) error {
	s.stateLock.Lock()
	defer s.stateLock.Unlock()

	readable := s.state != streamClosed &&
		s.state != streamRemoteClose &&
		s.state != streamReset
	if readable {
		s.readDeadline.set(t)
	}
	return nil
}
// SetWriteDeadline sets the deadline for future Write calls. It does nothing
// when the stream is closed, locally closed or reset. It always returns nil.
func (s *Stream) SetWriteDeadline(t time.Time) error {
	s.stateLock.Lock()
	defer s.stateLock.Unlock()

	writable := s.state != streamClosed &&
		s.state != streamLocalClose &&
		s.state != streamReset
	if writable {
		s.writeDeadline.set(t)
	}
	return nil
}
// Shrink does nothing: the internal buffer automatically shrinks itself.
func (s *Stream) Shrink() {}