2024-06-05 16:10:03 -04:00

218 lines
4.4 KiB
Go

package webtransport
import (
"errors"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/quic-go/quic-go"
)
// sessionCloseErrorCode is the QUIC stream error code that the closeWithSession
// methods in this file pass to CancelWrite / CancelRead on the underlying QUIC
// stream.
// NOTE(review): presumably the WEBTRANSPORT_SESSION_GONE code from the
// WebTransport over HTTP/3 draft — confirm against the spec.
const sessionCloseErrorCode quic.StreamErrorCode = 0x170d7b68
// SendStream is the write-only side of a WebTransport stream.
// It is implemented in this file by sendStream.
type SendStream interface {
	io.Writer
	io.Closer

	// StreamID returns the quic.StreamID of the stream.
	StreamID() quic.StreamID
	// CancelWrite aborts the send side with the given WebTransport error code.
	CancelWrite(StreamErrorCode)

	// SetWriteDeadline sets the deadline applied to writes.
	// NOTE(review): exact deadline semantics are those of the underlying
	// quic.SendStream — confirm against the quic-go documentation.
	SetWriteDeadline(time.Time) error
}
// ReceiveStream is the read-only side of a WebTransport stream.
// It is implemented in this file by receiveStream.
type ReceiveStream interface {
	io.Reader

	// StreamID returns the quic.StreamID of the stream.
	StreamID() quic.StreamID
	// CancelRead aborts the receive side with the given WebTransport error code.
	CancelRead(StreamErrorCode)

	// SetReadDeadline sets the deadline applied to reads.
	// NOTE(review): exact deadline semantics are those of the underlying
	// quic.ReceiveStream — confirm against the quic-go documentation.
	SetReadDeadline(time.Time) error
}
// Stream is a bidirectional WebTransport stream: it combines SendStream and
// ReceiveStream and adds a deadline setter covering both directions.
type Stream interface {
	SendStream
	ReceiveStream
	// SetDeadline sets both the read and the write deadline.
	SetDeadline(time.Time) error
}
// sendStream wraps a quic.SendStream and makes sure the WebTransport stream
// header is written to the QUIC stream before any payload or the FIN.
type sendStream struct {
	str quic.SendStream

	// WebTransport stream header.
	// Set by the constructor, set to nil once sent out.
	// Might be initialized to nil if this sendStream is part of an incoming bidirectional stream.
	// While a header write is only partially complete, this holds the unsent remainder.
	streamHdr []byte
	// hdrMx serializes header transmission and guards streamHdr.
	hdrMx sync.Mutex

	// onClose is invoked when the send side is closed, canceled, or fails
	// with a non-timeout error.
	onClose func()
}

var _ SendStream = &sendStream{}

// newSendStream wraps str. hdr is the stream header to send ahead of the first
// payload byte (may be nil); onClose is the close callback.
func newSendStream(str quic.SendStream, hdr []byte, onClose func()) *sendStream {
	return &sendStream{str: str, streamHdr: hdr, onClose: onClose}
}

// maybeSendStreamHeader writes whatever part of the stream header has not been
// sent yet. It is a no-op once the header is fully written (or if there never
// was one).
//
// If the write fails, for example because a write deadline expired, the unsent
// remainder is kept and the error is returned, so that the next call resumes
// where this one stopped. A one-shot guard would instead report the failure
// only once and let later writes put payload on the stream without a header.
func (s *sendStream) maybeSendStreamHeader() error {
	s.hdrMx.Lock()
	defer s.hdrMx.Unlock()

	if len(s.streamHdr) == 0 {
		return nil
	}
	n, err := s.str.Write(s.streamHdr)
	// Drop the bytes that made it out, even on error, so a retry never
	// duplicates part of the header.
	s.streamHdr = s.streamHdr[n:]
	if err != nil {
		return err
	}
	s.streamHdr = nil
	return nil
}
// Write sends the stream header if it is still pending and then writes b to
// the underlying QUIC stream.
//
// It returns the number of payload bytes written (header bytes are not
// counted). Any error, whether it occurred while sending the header or the
// payload, is handled the same way: unless it is a timeout, onClose is invoked,
// and the error is passed through maybeConvertStreamError before being
// returned.
func (s *sendStream) Write(b []byte) (int, error) {
	var n int
	err := s.maybeSendStreamHeader()
	if err == nil {
		n, err = s.str.Write(b)
	}
	// A timeout leaves the stream usable (the caller may extend the deadline
	// and retry), so only other errors count as the send side being done.
	if err != nil && !isTimeoutError(err) {
		s.onClose()
	}
	return n, maybeConvertStreamError(err)
}
// CancelWrite aborts the send side of the stream. The WebTransport error code
// e is mapped with webtransportCodeToHTTPCode before being handed to the
// underlying QUIC stream; afterwards the close callback is invoked.
func (s *sendStream) CancelWrite(e StreamErrorCode) {
	quicCode := webtransportCodeToHTTPCode(e)
	s.str.CancelWrite(quicCode)
	s.onClose()
}
// closeWithSession cancels the send side of the underlying QUIC stream with
// sessionCloseErrorCode. Unlike CancelWrite, it does not invoke onClose.
func (s *sendStream) closeWithSession() {
	s.str.CancelWrite(sessionCloseErrorCode)
}
// Close finishes the send side of the stream.
//
// The stream header has to go out before the stream is closed, so it is sent
// first if still pending; a failure there is returned as is and the stream is
// left open. Otherwise the close callback runs and the underlying QUIC stream
// is closed, with its error passed through maybeConvertStreamError.
func (s *sendStream) Close() error {
	hdrErr := s.maybeSendStreamHeader()
	if hdrErr != nil {
		return hdrErr
	}
	s.onClose()
	closeErr := s.str.Close()
	return maybeConvertStreamError(closeErr)
}
// SetWriteDeadline forwards t to the underlying QUIC stream's write deadline
// and returns the result after passing it through maybeConvertStreamError.
func (s *sendStream) SetWriteDeadline(t time.Time) error {
	err := s.str.SetWriteDeadline(t)
	return maybeConvertStreamError(err)
}
// StreamID returns the stream ID reported by the underlying QUIC send stream.
func (s *sendStream) StreamID() quic.StreamID {
	return s.str.StreamID()
}
// receiveStream wraps a quic.ReceiveStream and reports its termination through
// a callback.
type receiveStream struct {
	// str is the wrapped QUIC receive stream.
	str quic.ReceiveStream
	// onClose is invoked when a read fails with a non-timeout error or when
	// CancelRead is called.
	onClose func()
}
// Compile-time check that *receiveStream satisfies ReceiveStream.
var _ ReceiveStream = (*receiveStream)(nil)
// newReceiveStream wraps str; onClose is the callback fired when the receive
// side terminates.
func newReceiveStream(str quic.ReceiveStream, onClose func()) *receiveStream {
	rs := new(receiveStream)
	rs.str = str
	rs.onClose = onClose
	return rs
}
// Read reads from the underlying QUIC stream into b.
//
// On any error other than a timeout the close callback is invoked. The error
// is passed through maybeConvertStreamError before being returned.
func (s *receiveStream) Read(b []byte) (int, error) {
	n, readErr := s.str.Read(b)
	if readErr == nil {
		return n, nil
	}
	// Timeouts are recoverable, so they do not count as the stream ending.
	if !isTimeoutError(readErr) {
		s.onClose()
	}
	return n, maybeConvertStreamError(readErr)
}
// CancelRead aborts the receive side of the stream. The WebTransport error
// code e is mapped with webtransportCodeToHTTPCode before being handed to the
// underlying QUIC stream; afterwards the close callback is invoked.
func (s *receiveStream) CancelRead(e StreamErrorCode) {
	quicCode := webtransportCodeToHTTPCode(e)
	s.str.CancelRead(quicCode)
	s.onClose()
}
// closeWithSession cancels the receive side of the underlying QUIC stream with
// sessionCloseErrorCode. Unlike CancelRead, it does not invoke onClose.
func (s *receiveStream) closeWithSession() {
	s.str.CancelRead(sessionCloseErrorCode)
}
// SetReadDeadline forwards t to the underlying QUIC stream's read deadline and
// returns the result after passing it through maybeConvertStreamError.
func (s *receiveStream) SetReadDeadline(t time.Time) error {
	err := s.str.SetReadDeadline(t)
	return maybeConvertStreamError(err)
}
// StreamID returns the stream ID reported by the underlying QUIC receive stream.
func (s *receiveStream) StreamID() quic.StreamID {
	return s.str.StreamID()
}
// stream is a bidirectional stream assembled from a sendStream and a
// receiveStream. It tracks which directions have been closed and fires onClose
// once both are.
type stream struct {
	*sendStream
	*receiveStream

	// mx guards sendSideClosed and recvSideClosed.
	mx sync.Mutex
	// Set by registerClose when the respective direction reports closure.
	sendSideClosed, recvSideClosed bool
	// onClose is called by registerClose whenever it observes both
	// directions closed.
	onClose func()
}
// Compile-time check that *stream satisfies Stream.
var _ Stream = (*stream)(nil)
// newStream builds a bidirectional stream on top of str. hdr is the stream
// header handed to the send half (may be nil). Each half reports its own
// closure to registerClose, which in turn calls onClose once both directions
// are closed.
func newStream(str quic.Stream, hdr []byte, onClose func()) *stream {
	s := new(stream)
	s.onClose = onClose
	s.receiveStream = newReceiveStream(str, func() { s.registerClose(false) })
	s.sendStream = newSendStream(str, hdr, func() { s.registerClose(true) })
	return s
}
// registerClose records that one direction of the stream has been closed
// (the send side if isSendSide is true, the receive side otherwise) and calls
// onClose if both directions are now closed.
func (s *stream) registerClose(isSendSide bool) {
	// Update the flags under the lock, but release it before running the
	// callback.
	bothClosed := func() bool {
		s.mx.Lock()
		defer s.mx.Unlock()
		switch {
		case isSendSide:
			s.sendSideClosed = true
		default:
			s.recvSideClosed = true
		}
		return s.sendSideClosed && s.recvSideClosed
	}()
	if !bothClosed {
		return
	}
	s.onClose()
}
// closeWithSession cancels both directions of the stream with
// sessionCloseErrorCode: first the send half, then the receive half.
func (s *stream) closeWithSession() {
	s.sendStream.closeWithSession()
	s.receiveStream.closeWithSession()
}
// SetDeadline applies t as both the write and the read deadline. Both setters
// are always called; if both fail, the write-side error is the one returned.
func (s *stream) SetDeadline(t time.Time) error {
	writeErr := s.sendStream.SetWriteDeadline(t)
	if readErr := s.receiveStream.SetReadDeadline(t); writeErr == nil {
		return readErr
	}
	return writeErr
}
// StreamID returns the stream ID. Both embedded halves provide a StreamID
// method, so this method picks one explicitly (the receive half).
func (s *stream) StreamID() quic.StreamID {
	return s.receiveStream.StreamID()
}
// maybeConvertStreamError translates a *quic.StreamError found in err's chain
// into a *StreamError carrying the WebTransport error code.
//
// nil and errors that do not contain a *quic.StreamError are returned
// unchanged. If the QUIC error code cannot be mapped by
// httpCodeToWebtransportCode, an error wrapping the mapping failure is
// returned instead.
func maybeConvertStreamError(err error) error {
	var quicErr *quic.StreamError
	if err == nil || !errors.As(err, &quicErr) {
		return err
	}
	code, convErr := httpCodeToWebtransportCode(quicErr.ErrorCode)
	if convErr != nil {
		return fmt.Errorf("stream reset, but failed to convert stream error %d: %w", quicErr.ErrorCode, convErr)
	}
	return &StreamError{ErrorCode: code, Remote: quicErr.Remote}
}
// isTimeoutError reports whether err, or any error it wraps, is a net.Error
// whose Timeout method returns true. It returns false for a nil error.
//
// errors.As is used instead of a direct type assertion so that a timeout that
// has been wrapped (e.g. with fmt.Errorf and %w) is still recognized.
func isTimeoutError(err error) bool {
	var nerr net.Error
	return errors.As(err, &nerr) && nerr.Timeout()
}