2024-05-15 19:15:00 -04:00

125 lines
3.2 KiB
Go

package webtransport
import (
"context"
"fmt"
"net/http"
"net/url"
"sync"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/quic-go/quic-go/quicvarint"
)
// Dialer establishes WebTransport sessions over HTTP/3.
//
// All exported fields are optional; defaults are applied lazily on the first
// call to Dial. A Dialer contains a sync.Once and must not be copied after
// first use.
type Dialer struct {
	// RoundTripper is the HTTP/3 round tripper used to send the CONNECT
	// request that creates a WebTransport session.
	// If not set, reasonable defaults will be used.
	// In order for WebTransport to function, this implementation will:
	//  * overwrite the StreamHijacker and UniStreamHijacker
	//  * enable datagram support
	//  * set the MaxIncomingStreams to 100 on the quic.Config, if unset
	*http3.RoundTripper

	// StreamReorderingTimeout is the time an incoming WebTransport stream that
	// cannot be associated with a session is buffered.
	// This can happen if the response to a CONNECT request (that creates a new
	// session) is reordered, and arrives after the first WebTransport stream(s)
	// for that session.
	// Defaults to 5 seconds.
	StreamReorderingTimeout time.Duration

	// ctx is the dialer's internal context; ctxCancel cancels it and is
	// invoked by Close. Both are created in init.
	ctx       context.Context
	ctxCancel context.CancelFunc

	// initOnce guards init, which runs lazily on the first Dial.
	initOnce sync.Once
	// conns tracks the sessions created by Dial and receives the WebTransport
	// streams routed to it by the hijackers installed in init.
	conns sessionManager
}
// init applies defaults and installs the WebTransport hooks on the HTTP/3
// round tripper. It is run exactly once, through d.initOnce, by the first Dial.
//
// A caller-supplied RoundTripper (and its QuicConfig) is modified in place:
// datagrams are enabled, WebTransport support is advertised in SETTINGS,
// both stream hijackers are overwritten, and MaxIncomingStreams is defaulted.
func (d *Dialer) init() {
	const (
		defaultStreamReorderingTimeout = 5 * time.Second
		defaultMaxIncomingStreams      = 100
		// settingEnableWebTransport is the SETTINGS_ENABLE_WEBTRANSPORT
		// identifier from draft-ietf-webtrans-http3-02; per that draft a client
		// must send it with value 1 to indicate WebTransport support.
		// NOTE(review): if the package already declares this constant
		// elsewhere, reference that declaration instead.
		settingEnableWebTransport = 0x2b603742
	)

	timeout := d.StreamReorderingTimeout
	if timeout == 0 {
		timeout = defaultStreamReorderingTimeout
	}
	d.conns = *newSessionManager(timeout)
	d.ctx, d.ctxCancel = context.WithCancel(context.Background())

	if d.RoundTripper == nil {
		d.RoundTripper = &http3.RoundTripper{}
	}
	d.RoundTripper.EnableDatagrams = true
	if d.RoundTripper.AdditionalSettings == nil {
		d.RoundTripper.AdditionalSettings = make(map[uint64]uint64)
	}
	// Without this entry the map above would be allocated for nothing and the
	// peer would never learn that this client speaks WebTransport.
	d.RoundTripper.AdditionalSettings[settingEnableWebTransport] = 1

	// Bidirectional streams: a WebTransport stream carries the
	// webTransportFrameType frame type followed by a varint session ID.
	// Any other frame type is left to regular HTTP/3 handling.
	d.RoundTripper.StreamHijacker = func(ft http3.FrameType, conn quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
		if isWebTransportError(e) {
			// Claim the stream so http3 does not try to process it further.
			return true, nil
		}
		if ft != webTransportFrameType {
			return false, nil
		}
		id, err := quicvarint.Read(quicvarint.NewReader(str))
		if err != nil {
			if isWebTransportError(err) {
				return true, nil
			}
			return false, err
		}
		d.conns.AddStream(conn, str, sessionID(id))
		return true, nil
	}
	// Unidirectional streams of webTransportUniStreamType (or ones that failed
	// with a WebTransport error) are handed to the session manager.
	d.RoundTripper.UniStreamHijacker = func(st http3.StreamType, conn quic.Connection, str quic.ReceiveStream, err error) (hijacked bool) {
		if st != webTransportUniStreamType && !isWebTransportError(err) {
			return false
		}
		d.conns.AddUniStream(conn, str)
		return true
	}

	if d.QuicConfig == nil {
		d.QuicConfig = &quic.Config{}
	}
	if d.QuicConfig.MaxIncomingStreams == 0 {
		d.QuicConfig.MaxIncomingStreams = defaultMaxIncomingStreams
	}
}
// Dial opens a new WebTransport session to urlStr.
//
// It sends an HTTP/3 CONNECT request (Proto "webtransport") through
// d.RoundTripper. The request carries a copy of reqHdr (which may be nil) plus
// the WebTransport draft offer header. The caller's reqHdr is never modified.
//
// Return values:
//   - transport or URL errors: (nil, nil, err);
//   - a non-2xx status, or a response body that does not expose the HTTP/3
//     stream: (rsp, nil, err), and the caller owns rsp;
//   - success: (rsp, session, nil).
func (d *Dialer) Dial(ctx context.Context, urlStr string, reqHdr http.Header) (*http.Response, *Session, error) {
	d.initOnce.Do(d.init)

	u, err := url.Parse(urlStr)
	if err != nil {
		return nil, nil, err
	}

	// Work on a copy: the caller's header map may be reused for other
	// requests or shared across goroutines.
	hdr := reqHdr.Clone()
	if hdr == nil {
		hdr = http.Header{}
	}
	hdr.Set(webTransportDraftOfferHeaderKey, "1")

	req := &http.Request{
		Method: http.MethodConnect,
		Header: hdr,
		Proto:  "webtransport",
		Host:   u.Host,
		URL:    u,
	}
	req = req.WithContext(ctx)

	// The request stream is handed to the session below (its stream ID is the
	// session ID), so it must not be closed after the request is sent.
	rsp, err := d.RoundTripper.RoundTripOpt(req, http3.RoundTripOpt{DontCloseRequestStream: true})
	if err != nil {
		return nil, nil, err
	}
	if rsp.StatusCode < 200 || rsp.StatusCode >= 300 {
		return rsp, nil, fmt.Errorf("received status %d", rsp.StatusCode)
	}

	// Checked assertions: a custom or wrapped body would otherwise panic here.
	streamer, ok := rsp.Body.(http3.HTTPStreamer)
	if !ok {
		return rsp, nil, fmt.Errorf("response body %T does not implement http3.HTTPStreamer", rsp.Body)
	}
	hijacker, ok := rsp.Body.(http3.Hijacker)
	if !ok {
		return rsp, nil, fmt.Errorf("response body %T does not implement http3.Hijacker", rsp.Body)
	}

	str := streamer.HTTPStream()
	sess := d.conns.AddSession(
		hijacker.StreamCreator(),
		sessionID(str.StreamID()),
		str,
	)
	return rsp, sess, nil
}
// Close cancels the dialer's internal context. It does not itself close the
// RoundTripper or the sessions tracked by the dialer.
//
// Close may be called before any Dial and may be called more than once; it
// always returns nil.
func (d *Dialer) Close() error {
	// ctxCancel is only set by init, which runs on the first Dial; calling a
	// nil CancelFunc would panic for a Dialer that never dialed.
	// NOTE(review): this read is not synchronized with initOnce, so Close
	// must not race with the first Dial.
	if d.ctxCancel != nil {
		d.ctxCancel()
	}
	return nil
}