2022-04-06 10:36:06 -04:00

172 lines
4.9 KiB
Go

package quic
import (
"errors"
"sync"
"github.com/lucas-clemente/quic-go/internal/ackhandler"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// framer assembles the frames that go into outgoing packets: control frames
// queued via QueueControlFrame, and STREAM frames taken from streams that
// registered themselves via AddActiveStream.
type framer interface {
	// HasData reports whether a stream or a control frame is waiting to be sent.
	HasData() bool

	// QueueControlFrame enqueues a control frame for a later packet.
	QueueControlFrame(frame wire.Frame)
	// AppendControlFrames appends queued control frames to frames while they fit
	// into maxLen bytes, and returns the extended slice and the bytes added.
	AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount)

	// AddActiveStream marks the stream with the given ID as having data to send.
	AddActiveStream(id protocol.StreamID)
	// AppendStreamFrames appends STREAM frames from active streams to frames for
	// a packet with maxLen bytes of room, and returns the extended slice and the
	// bytes added.
	AppendStreamFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount)

	// Handle0RTTRejection drops queued state after 0-RTT was rejected.
	Handle0RTTRejection() error
}
// framerI is the concrete framer.
//
// Two independent locks are used: mutex guards the stream scheduling state
// and controlFrameMutex guards the control frame queue.
type framerI struct {
	// mutex guards activeStreams and streamQueue.
	mutex         sync.Mutex
	streamGetter  streamGetter
	version       protocol.VersionNumber
	activeStreams map[protocol.StreamID]struct{} // set of IDs currently in streamQueue
	streamQueue   []protocol.StreamID            // round-robin order of streams with data

	// controlFrameMutex guards controlFrames.
	controlFrameMutex sync.Mutex
	controlFrames     []wire.Frame
}

// Compile-time check that *framerI implements framer.
var _ framer = (*framerI)(nil)
// newFramer returns a framer that obtains send streams through streamGetter
// and measures frame lengths for QUIC version v.
func newFramer(
	streamGetter streamGetter,
	v protocol.VersionNumber,
) framer {
	f := &framerI{
		version:      v,
		streamGetter: streamGetter,
	}
	f.activeStreams = make(map[protocol.StreamID]struct{})
	return f
}
// HasData reports whether any stream is queued for sending or any control
// frame is waiting. The two queues are checked under their own locks, one
// after the other, never both at once.
func (f *framerI) HasData() bool {
	f.mutex.Lock()
	streamsPending := len(f.streamQueue) != 0
	f.mutex.Unlock()
	if streamsPending {
		return true
	}

	f.controlFrameMutex.Lock()
	defer f.controlFrameMutex.Unlock()
	return len(f.controlFrames) != 0
}
// QueueControlFrame appends frame to the control frame queue; the queue is
// drained by AppendControlFrames.
func (f *framerI) QueueControlFrame(frame wire.Frame) {
	f.controlFrameMutex.Lock()
	defer f.controlFrameMutex.Unlock()
	f.controlFrames = append(f.controlFrames, frame)
}
// AppendControlFrames moves queued control frames into frames for as long as
// their combined length stays within maxLen.
//
// Frames are taken from the back of the queue, so the most recently queued
// frame goes first. Draining stops at the first frame that does not fit, even
// if an older, smaller one still would. It returns the extended slice and the
// total length of the frames that were appended.
func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) {
	f.controlFrameMutex.Lock()
	defer f.controlFrameMutex.Unlock()

	var added protocol.ByteCount
	for n := len(f.controlFrames); n > 0; n-- {
		last := f.controlFrames[n-1]
		size := last.Length(f.version)
		if added+size > maxLen {
			break
		}
		frames = append(frames, ackhandler.Frame{Frame: last})
		added += size
		f.controlFrames = f.controlFrames[:n-1]
	}
	return frames, added
}
// AddActiveStream schedules the stream with the given ID for sending. A stream
// that is already scheduled is not queued a second time.
func (f *framerI) AddActiveStream(id protocol.StreamID) {
	f.mutex.Lock()
	defer f.mutex.Unlock()

	if _, alreadyQueued := f.activeStreams[id]; alreadyQueued {
		return
	}
	f.activeStreams[id] = struct{}{}
	f.streamQueue = append(f.streamQueue, id)
}
// AppendStreamFrames appends STREAM frames from the active streams to frames,
// sized for a packet with maxLen bytes of room. It returns the extended slice
// together with the number of bytes the appended frames occupy.
//
// Streams are served round-robin. Each stream queued at entry is popped from
// the front at most once per call, and is re-appended at the back only if it
// still has data. Popping stops once fewer than protocol.MinStreamFrameSize
// bytes are left.
//
// After the loop, the last appended frame is rewritten with
// DataLenPresent = false, and the returned length is reduced by the bytes
// that saves. Staying within maxLen assumes popStreamFrame honors the size it
// is given (not visible here).
func (f *framerI) AppendStreamFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) {
var length protocol.ByteCount
// lastFrame points at the most recently appended STREAM frame; its DataLen
// field is stripped after the loop.
var lastFrame *ackhandler.Frame
f.mutex.Lock()
// pop STREAM frames, until fewer than MinStreamFrameSize bytes are left in the packet.
// Only the streams queued at entry are visited: streams re-appended below end
// up behind them and are served on the next call.
numActiveStreams := len(f.streamQueue)
for i := 0; i < numActiveStreams; i++ {
if protocol.MinStreamFrameSize+length > maxLen {
break
}
id := f.streamQueue[0]
f.streamQueue = f.streamQueue[1:]
// This should never return an error. Better check it anyway.
// The stream will only be in the streamQueue, if it enqueued itself there.
str, err := f.streamGetter.GetOrOpenSendStream(id)
// The stream can be nil if it completed after it said it had data.
if str == nil || err != nil {
delete(f.activeStreams, id)
continue
}
remainingLen := maxLen - length
// For the last STREAM frame, we'll remove the DataLen field later.
// Therefore, we can pretend to have more bytes available when popping
// the STREAM frame (which will always have the DataLen set).
remainingLen += quicvarint.Len(uint64(remainingLen))
frame, hasMoreData := str.popStreamFrame(remainingLen)
if hasMoreData { // put the stream back in the queue (at the end)
f.streamQueue = append(f.streamQueue, id)
} else { // no more data to send. Stream is not active any more
delete(f.activeStreams, id)
}
// The frame can be nil
// * if the send stream was canceled after it said it had data
// * if the remaining size doesn't allow us to add another STREAM frame
if frame == nil {
continue
}
// The slice receives a copy of *frame, but the copy's Frame field holds the
// same *wire.StreamFrame pointer. Modifying it through lastFrame after the
// loop therefore also changes the copy in frames.
frames = append(frames, *frame)
length += frame.Length(f.version)
lastFrame = frame
}
f.mutex.Unlock()
if lastFrame != nil {
lastFrameLen := lastFrame.Length(f.version)
// account for the smaller size of the last STREAM frame
lastFrame.Frame.(*wire.StreamFrame).DataLenPresent = false
length += lastFrame.Length(f.version) - lastFrameLen
}
return frames, length
}
// Handle0RTTRejection discards framer state after 0-RTT was rejected.
//
// It empties the stream queue and the active-stream set, so a stream is only
// scheduled again after another AddActiveStream call. Queued control frames
// are then filtered in place: DATA_BLOCKED, STREAM_DATA_BLOCKED and
// STREAMS_BLOCKED frames are dropped, and every other frame is kept in its
// original order.
//
// It returns an error if a MAX_DATA, MAX_STREAM_DATA or MAX_STREAMS frame is
// queued. In that case the stream state has already been cleared, but the
// control frame queue is left unchanged. Both mutexes are released on every
// return path.
func (f *framerI) Handle0RTTRejection() error {
	f.mutex.Lock()
	defer f.mutex.Unlock()
	f.controlFrameMutex.Lock()
	// Deferred rather than unlocked at the end: the early error return below
	// would otherwise leave controlFrameMutex held forever. That would deadlock
	// every later HasData, QueueControlFrame and AppendControlFrames call.
	defer f.controlFrameMutex.Unlock()

	f.streamQueue = f.streamQueue[:0]
	for id := range f.activeStreams {
		delete(f.activeStreams, id)
	}

	// Validate before compacting, so that an error cannot leave the queue
	// half-rewritten (with duplicated entries) by the in-place filter below.
	for _, frame := range f.controlFrames {
		switch frame.(type) {
		case *wire.MaxDataFrame, *wire.MaxStreamDataFrame, *wire.MaxStreamsFrame:
			return errors.New("didn't expect MAX_DATA / MAX_STREAM_DATA / MAX_STREAMS frame to be sent in 0-RTT")
		}
	}

	// In-place filter: kept shares the backing array and never runs ahead of
	// the range index, so no unread element is overwritten.
	kept := f.controlFrames[:0]
	for _, frame := range f.controlFrames {
		switch frame.(type) {
		case *wire.DataBlockedFrame, *wire.StreamDataBlockedFrame, *wire.StreamsBlockedFrame:
			// dropped
		default:
			kept = append(kept, frame)
		}
	}
	f.controlFrames = kept
	return nil
}