218 lines
4.4 KiB
Go
218 lines
4.4 KiB
Go
package webtransport
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/quic-go/quic-go"
|
|
)
|
|
|
|
// sessionCloseErrorCode is the QUIC stream error code that closeWithSession
// uses to reset the send side and abort the receive side of every stream
// belonging to a session that is being closed.
// NOTE(review): the value appears to be WT_SESSION_GONE from the WebTransport
// over HTTP/3 draft — confirm against the draft version in use.
const sessionCloseErrorCode = quic.StreamErrorCode(0x170d7b68)
|
|
|
|
type SendStream interface {
|
|
io.Writer
|
|
io.Closer
|
|
|
|
StreamID() quic.StreamID
|
|
CancelWrite(StreamErrorCode)
|
|
|
|
SetWriteDeadline(time.Time) error
|
|
}
|
|
|
|
type ReceiveStream interface {
|
|
io.Reader
|
|
|
|
StreamID() quic.StreamID
|
|
CancelRead(StreamErrorCode)
|
|
|
|
SetReadDeadline(time.Time) error
|
|
}
|
|
|
|
type Stream interface {
|
|
SendStream
|
|
ReceiveStream
|
|
SetDeadline(time.Time) error
|
|
}
|
|
|
|
type sendStream struct {
|
|
str quic.SendStream
|
|
// WebTransport stream header.
|
|
// Set by the constructor, set to nil once sent out.
|
|
// Might be initialized to nil if this sendStream is part of an incoming bidirectional stream.
|
|
streamHdr []byte
|
|
|
|
onClose func()
|
|
|
|
once sync.Once
|
|
}
|
|
|
|
var _ SendStream = &sendStream{}
|
|
|
|
func newSendStream(str quic.SendStream, hdr []byte, onClose func()) *sendStream {
|
|
return &sendStream{str: str, streamHdr: hdr, onClose: onClose}
|
|
}
|
|
|
|
func (s *sendStream) maybeSendStreamHeader() (err error) {
|
|
s.once.Do(func() {
|
|
if _, e := s.str.Write(s.streamHdr); e != nil {
|
|
err = e
|
|
return
|
|
}
|
|
s.streamHdr = nil
|
|
})
|
|
return
|
|
}
|
|
|
|
func (s *sendStream) Write(b []byte) (int, error) {
|
|
if err := s.maybeSendStreamHeader(); err != nil {
|
|
return 0, err
|
|
}
|
|
n, err := s.str.Write(b)
|
|
if err != nil && !isTimeoutError(err) {
|
|
s.onClose()
|
|
}
|
|
return n, maybeConvertStreamError(err)
|
|
}
|
|
|
|
func (s *sendStream) CancelWrite(e StreamErrorCode) {
|
|
s.str.CancelWrite(webtransportCodeToHTTPCode(e))
|
|
s.onClose()
|
|
}
|
|
|
|
func (s *sendStream) closeWithSession() {
|
|
s.str.CancelWrite(sessionCloseErrorCode)
|
|
}
|
|
|
|
func (s *sendStream) Close() error {
|
|
if err := s.maybeSendStreamHeader(); err != nil {
|
|
return err
|
|
}
|
|
s.onClose()
|
|
return maybeConvertStreamError(s.str.Close())
|
|
}
|
|
|
|
func (s *sendStream) SetWriteDeadline(t time.Time) error {
|
|
return maybeConvertStreamError(s.str.SetWriteDeadline(t))
|
|
}
|
|
|
|
func (s *sendStream) StreamID() quic.StreamID {
|
|
return s.str.StreamID()
|
|
}
|
|
|
|
type receiveStream struct {
|
|
str quic.ReceiveStream
|
|
onClose func()
|
|
}
|
|
|
|
var _ ReceiveStream = &receiveStream{}
|
|
|
|
func newReceiveStream(str quic.ReceiveStream, onClose func()) *receiveStream {
|
|
return &receiveStream{str: str, onClose: onClose}
|
|
}
|
|
|
|
func (s *receiveStream) Read(b []byte) (int, error) {
|
|
n, err := s.str.Read(b)
|
|
if err != nil && !isTimeoutError(err) {
|
|
s.onClose()
|
|
}
|
|
return n, maybeConvertStreamError(err)
|
|
}
|
|
|
|
func (s *receiveStream) CancelRead(e StreamErrorCode) {
|
|
s.str.CancelRead(webtransportCodeToHTTPCode(e))
|
|
s.onClose()
|
|
}
|
|
|
|
func (s *receiveStream) closeWithSession() {
|
|
s.str.CancelRead(sessionCloseErrorCode)
|
|
}
|
|
|
|
func (s *receiveStream) SetReadDeadline(t time.Time) error {
|
|
return maybeConvertStreamError(s.str.SetReadDeadline(t))
|
|
}
|
|
|
|
func (s *receiveStream) StreamID() quic.StreamID {
|
|
return s.str.StreamID()
|
|
}
|
|
|
|
type stream struct {
|
|
*sendStream
|
|
*receiveStream
|
|
|
|
mx sync.Mutex
|
|
sendSideClosed, recvSideClosed bool
|
|
onClose func()
|
|
}
|
|
|
|
var _ Stream = &stream{}
|
|
|
|
func newStream(str quic.Stream, hdr []byte, onClose func()) *stream {
|
|
s := &stream{onClose: onClose}
|
|
s.sendStream = newSendStream(str, hdr, func() { s.registerClose(true) })
|
|
s.receiveStream = newReceiveStream(str, func() { s.registerClose(false) })
|
|
return s
|
|
}
|
|
|
|
func (s *stream) registerClose(isSendSide bool) {
|
|
s.mx.Lock()
|
|
if isSendSide {
|
|
s.sendSideClosed = true
|
|
} else {
|
|
s.recvSideClosed = true
|
|
}
|
|
isClosed := s.sendSideClosed && s.recvSideClosed
|
|
s.mx.Unlock()
|
|
|
|
if isClosed {
|
|
s.onClose()
|
|
}
|
|
}
|
|
|
|
func (s *stream) closeWithSession() {
|
|
s.sendStream.closeWithSession()
|
|
s.receiveStream.closeWithSession()
|
|
}
|
|
|
|
func (s *stream) SetDeadline(t time.Time) error {
|
|
err1 := s.sendStream.SetWriteDeadline(t)
|
|
err2 := s.receiveStream.SetReadDeadline(t)
|
|
if err1 != nil {
|
|
return err1
|
|
}
|
|
return err2
|
|
}
|
|
|
|
func (s *stream) StreamID() quic.StreamID {
|
|
return s.receiveStream.StreamID()
|
|
}
|
|
|
|
func maybeConvertStreamError(err error) error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
var streamErr *quic.StreamError
|
|
if errors.As(err, &streamErr) {
|
|
errorCode, cerr := httpCodeToWebtransportCode(streamErr.ErrorCode)
|
|
if cerr != nil {
|
|
return fmt.Errorf("stream reset, but failed to convert stream error %d: %w", streamErr.ErrorCode, cerr)
|
|
}
|
|
return &StreamError{
|
|
ErrorCode: errorCode,
|
|
Remote: streamErr.Remote,
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
// isTimeoutError reports whether err, or any error it wraps, is a net.Error
// whose Timeout method returns true (e.g. an expired read/write deadline).
// It returns false for nil.
//
// errors.As is used rather than a bare type assertion so that a deadline
// error wrapped with %w is still recognized as a timeout. Otherwise callers
// would treat it as fatal and close the stream side.
func isTimeoutError(err error) bool {
	var nerr net.Error
	return errors.As(err, &nerr) && nerr.Timeout()
}
|